+ {% if error is defined %}
+
Error: {{ error }}
+ {% endif %}
Please click the button below if you agree to the
privacy policy of this homeserver.
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 6ea1b50a62..73284e48ec 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -16,7 +16,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError
+from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet, parse_string
@@ -95,29 +95,32 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.RECAPTCHA, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.recaptcha_template.render(
session=session,
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
sitekey=self.hs.config.recaptcha_public_key,
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
- success = await self.auth_handler.add_oob_auth(
- LoginType.TERMS, authdict, request.getClientIP()
- )
-
- if success:
- html = self.success_template.render()
- else:
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.TERMS, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ # Authentication failed, let user try again
html = self.terms_template.render(
session=session,
terms_url="%s_matrix/consent?v=%s"
@@ -127,10 +130,16 @@ class AuthRestServlet(RestServlet):
),
myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
+ error=e.msg,
)
+ else:
+ # No LoginError was raised, so authentication was successful
+ html = self.success_template.render()
+
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
+
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/static/client/register/style.css b/synapse/static/client/register/style.css
index 5a7b6eebf2..8a39b5d0f5 100644
--- a/synapse/static/client/register/style.css
+++ b/synapse/static/client/register/style.css
@@ -57,4 +57,8 @@ textarea, input {
background-color: #f8f8f8;
border: 1px #ccc solid;
-}
\ No newline at end of file
+}
+
+.error {
+ color: red;
+}
--
cgit 1.5.1
From bec01c075829730cf467572e2fcf93e15372b0e9 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 18 Aug 2021 09:22:07 -0400
Subject: Convert room member storage tuples to attrs. (#10629)
Instead of using namedtuples. This helps with asserting type hints
and code completion.
---
changelog.d/10629.misc | 1 +
synapse/handlers/initial_sync.py | 2 +-
synapse/handlers/sync.py | 18 +++++-----
synapse/storage/databases/main/roommember.py | 8 +++--
synapse/storage/databases/main/user_directory.py | 2 +-
synapse/storage/roommember.py | 43 ++++++++++++++++--------
tests/replication/slave/storage/test_events.py | 9 +++--
7 files changed, 54 insertions(+), 29 deletions(-)
create mode 100644 changelog.d/10629.misc
(limited to 'synapse')
diff --git a/changelog.d/10629.misc b/changelog.d/10629.misc
new file mode 100644
index 0000000000..cca1eb6c57
--- /dev/null
+++ b/changelog.d/10629.misc
@@ -0,0 +1 @@
+Convert room member storage tuples to `attrs` classes.
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e1c544a3c9..4e8f7f1d85 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler):
limit = 10
async def handle_room(event: RoomsForUser):
- d = {
+ d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index eba915819e..b7b299961f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -701,7 +701,7 @@ class SyncHandler:
name_id = state_ids.get((EventTypes.Name, ""))
canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
- summary = {}
+ summary: JsonDict = {}
empty_ms = MemberSummary([], 0)
# TODO: only send these when they change.
@@ -2076,21 +2076,23 @@ class SyncHandler:
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
- for room_id, event_pos in joined_rooms:
- if not event_pos.persisted_after(room_key):
- joined_room_ids.add(room_id)
+ for joined_room in joined_rooms:
+ if not joined_room.event_pos.persisted_after(room_key):
+ joined_room_ids.add(joined_room.room_id)
continue
- logger.info("User joined room after current token: %s", room_id)
+ logger.info("User joined room after current token: %s", joined_room.room_id)
extrems = (
await self.store.get_forward_extremities_for_room_at_stream_ordering(
- room_id, event_pos.stream
+ joined_room.room_id, joined_room.event_pos.stream
)
)
- users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
+ users_in_room = await self.state.get_current_users_in_room(
+ joined_room.room_id, extrems
+ )
if user_id in users_in_room:
- joined_room_ids.add(room_id)
+ joined_room_ids.add(joined_room.room_id)
return frozenset(joined_room_ids)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e8157ba3d4..c2f6b9d63d 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
- async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
+ async def get_invited_rooms_for_local_user(
+ self, user_id: str
+ ) -> List[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@@ -522,7 +524,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn,
)
- async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
+ async def get_rooms_for_user(
+ self, user_id: str, on_invalidate=None
+ ) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 9d28d69ac7..65dde67ae9 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
async def update_profile_in_user_dir(
- self, user_id: str, display_name: str, avatar_url: str
+ self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
) -> None:
"""
Update or add a user's profile in the user directory.
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c34fbf21bc..0ff66debdf 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -14,25 +14,40 @@
# limitations under the License.
import logging
-from collections import namedtuple
+from typing import List, Optional, Tuple
+
+import attr
+
+from synapse.types import PersistedEventPosition
logger = logging.getLogger(__name__)
-RoomsForUser = namedtuple(
- "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True)
+class RoomsForUser:
+ room_id: str
+ sender: str
+ membership: str
+ event_id: str
+ stream_ordering: int
+
+
+@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True)
+class GetRoomsForUserWithStreamOrdering:
+ room_id: str
+ event_pos: PersistedEventPosition
-GetRoomsForUserWithStreamOrdering = namedtuple(
- "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
-)
+@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True)
+class ProfileInfo:
+ avatar_url: Optional[str]
+ display_name: Optional[str]
-# We store this using a namedtuple so that we save about 3x space over using a
-# dict.
-ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
-# "members" points to a truncated list of (user_id, event_id) tuples for users of
-# a given membership type, suitable for use in calculating heroes for a room.
-# "count" points to the total numberr of users of a given membership type.
-MemberSummary = namedtuple("MemberSummary", ("members", "count"))
+@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True)
+class MemberSummary:
+ # A truncated list of (user_id, event_id) tuples for users of a given
+ # membership type, suitable for use in calculating heroes for a room.
+ members: List[Tuple[str, str]]
+ # The total number of users of a given membership type.
+ count: int
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index db80a0bdbd..8d1b0606c4 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,7 +20,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
@@ -216,7 +216,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_rooms_for_user_with_stream_ordering",
(USER_ID_2,),
- {(ROOM_ID, expected_pos)},
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -305,7 +305,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
expected_pos = PersistedEventPosition(
"master", j2.internal_metadata.stream_ordering
)
- self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
+ self.assertEqual(
+ joined_rooms,
+ {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+ )
event_id = 0
--
cgit 1.5.1
From d9856d9150c05577282cb52f548de7737bb1e454 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 18 Aug 2021 11:00:37 -0400
Subject: Fix weakref_slot parameter for room member storage attrs. (#10642)
Follow-up to #10629 which set it to true, not false.
---
changelog.d/10642.misc | 1 +
synapse/storage/roommember.py | 8 ++++----
2 files changed, 5 insertions(+), 4 deletions(-)
create mode 100644 changelog.d/10642.misc
(limited to 'synapse')
diff --git a/changelog.d/10642.misc b/changelog.d/10642.misc
new file mode 100644
index 0000000000..cca1eb6c57
--- /dev/null
+++ b/changelog.d/10642.misc
@@ -0,0 +1 @@
+Convert room member storage tuples to `attrs` classes.
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0ff66debdf..9fad67ce48 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -23,7 +23,7 @@ from synapse.types import PersistedEventPosition
logger = logging.getLogger(__name__)
-@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class RoomsForUser:
room_id: str
sender: str
@@ -32,19 +32,19 @@ class RoomsForUser:
stream_ordering: int
-@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class GetRoomsForUserWithStreamOrdering:
room_id: str
event_pos: PersistedEventPosition
-@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class ProfileInfo:
avatar_url: Optional[str]
display_name: Optional[str]
-@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
class MemberSummary:
# A truncated list of (user_id, event_id) tuples for users of a given
# membership type, suitable for use in calculating heroes for a room.
--
cgit 1.5.1
From 0c3565da4cdbe53646ae0bc737900526a1d3df67 Mon Sep 17 00:00:00 2001
From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com>
Date: Wed, 18 Aug 2021 19:53:20 +0200
Subject: Additional type hints for the proxy agent and SRV resolver modules.
(#10608)
---
changelog.d/10608.misc | 1 +
mypy.ini | 3 +++
synapse/http/additional_resource.py | 13 +++++++---
synapse/http/federation/srv_resolver.py | 45 ++++++++++++++++++---------------
synapse/http/proxyagent.py | 4 +--
5 files changed, 41 insertions(+), 25 deletions(-)
create mode 100644 changelog.d/10608.misc
(limited to 'synapse')
diff --git a/changelog.d/10608.misc b/changelog.d/10608.misc
new file mode 100644
index 0000000000..875bdd2fd0
--- /dev/null
+++ b/changelog.d/10608.misc
@@ -0,0 +1 @@
+Improve type hints for the proxy agent and SRV resolver modules. Contributed by @dklimpel.
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index e1b9405daa..107f4de76c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -28,10 +28,13 @@ files =
synapse/federation,
synapse/groups,
synapse/handlers,
+ synapse/http/additional_resource.py,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
+ synapse/http/federation/srv_resolver.py,
synapse/http/federation/well_known_resolver.py,
synapse/http/matrixfederationclient.py,
+ synapse/http/proxyagent.py,
synapse/http/servlet.py,
synapse/http/server.py,
synapse/http/site.py,
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 55ea97a07f..9a2684aca4 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
from synapse.http.server import DirectServeJsonResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources
@@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""
- def __init__(self, hs, handler):
+ def __init__(self, hs: "HomeServer", handler):
"""Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has
@@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource):
``request.write()``, and call ``request.finish()``.
Args:
- hs (synapse.server.HomeServer): homeserver
+ hs: homeserver
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
super().__init__()
self._handler = handler
- def _async_render(self, request):
+ def _async_render(self, request: Request):
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b8ed4ec905..f68646fd0d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
import logging
import random
import time
-from typing import List
+from typing import Callable, Dict, List
import attr
@@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable
logger = logging.getLogger(__name__)
-SERVER_CACHE = {}
+SERVER_CACHE: Dict[bytes, List["Server"]] = {}
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
class Server:
"""
Our record of an individual server which can be tried to reach a destination.
Attributes:
- host (bytes): target hostname
- port (int):
- priority (int):
- weight (int):
- expires (int): when the cache should expire this record - in *seconds* since
+ host: target hostname
+ port:
+ priority:
+ weight:
+ expires: when the cache should expire this record - in *seconds* since
the epoch
"""
- host = attr.ib()
- port = attr.ib()
- priority = attr.ib(default=0)
- weight = attr.ib(default=0)
- expires = attr.ib(default=0)
+ host: bytes
+ port: int
+ priority: int = 0
+ weight: int = 0
+ expires: int = 0
-def _sort_server_list(server_list):
+def _sort_server_list(server_list: List[Server]) -> List[Server]:
"""Given a list of SRV records sort them into priority order and shuffle
each priority with the given weight.
"""
- priority_map = {}
+ priority_map: Dict[int, List[Server]] = {}
for server in server_list:
priority_map.setdefault(server.priority, []).append(server)
@@ -103,11 +103,16 @@ class SrvResolver:
Args:
dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
- cache (dict): cache object
- get_time (callable): clock implementation. Should return seconds since the epoch
+ cache: cache object
+ get_time: clock implementation. Should return seconds since the epoch
"""
- def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+ def __init__(
+ self,
+ dns_client=client,
+ cache: Dict[bytes, List[Server]] = SERVER_CACHE,
+ get_time: Callable[[], float] = time.time,
+ ):
self._dns_client = dns_client
self._cache = cache
self._get_time = get_time
@@ -116,7 +121,7 @@ class SrvResolver:
"""Look up a SRV record
Args:
- service_name (bytes): record to look up
+ service_name: record to look up
Returns:
a list of the SRV records, or an empty list if none found
@@ -158,7 +163,7 @@ class SrvResolver:
and answers[0].payload
and answers[0].payload.target == dns.Name(b".")
):
- raise ConnectError("Service %s unavailable" % service_name)
+ raise ConnectError(f"Service {service_name!r} unavailable")
servers = []
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a3f31452d0..6fd88bde20 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase):
raise ValueError(f"Invalid URI {uri!r}")
parsed_uri = URI.fromBytes(uri)
- pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+ pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}"
request_path = parsed_uri.originForm
should_skip_proxy = False
@@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase):
)
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
- pool_key = ("http-proxy", self.http_proxy_endpoint)
+ pool_key = "http-proxy"
endpoint = self.http_proxy_endpoint
request_path = uri
elif (
--
cgit 1.5.1
From 220f901229a506a82aedc51c5923768bf935ea4f Mon Sep 17 00:00:00 2001
From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com>
Date: Thu, 19 Aug 2021 11:25:05 +0200
Subject: Remove not needed database updates in modify user admin API (#10627)
---
changelog.d/10627.misc | 1 +
docs/admin_api/user_admin_api.md | 8 +++-
synapse/rest/admin/users.py | 55 ++++++++++++++---------
synapse/storage/databases/main/registration.py | 25 ++++++++---
tests/rest/admin/test_user.py | 62 ++++++++++++++++++++++++--
5 files changed, 118 insertions(+), 33 deletions(-)
create mode 100644 changelog.d/10627.misc
(limited to 'synapse')
diff --git a/changelog.d/10627.misc b/changelog.d/10627.misc
new file mode 100644
index 0000000000..e6d314976e
--- /dev/null
+++ b/changelog.d/10627.misc
@@ -0,0 +1 @@
+Remove not needed database updates in modify user admin API.
\ No newline at end of file
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index 6a9335d6ec..60dc913915 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -21,11 +21,15 @@ It returns a JSON body like the following:
"threepids": [
{
"medium": "email",
- "address": ""
+ "address": "",
+ "added_at": 1586458409743,
+ "validated_at": 1586458409743
},
{
"medium": "email",
- "address": ""
+ "address": "",
+ "added_at": 1586458409743,
+ "validated_at": 1586458409743
}
],
"avatar_url": "",
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 3c8a0c6883..c1a1ba645e 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
if not isinstance(deactivate, bool):
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
- # convert into List[Tuple[str, str]]
+ # convert List[Dict[str, str]] into Set[Tuple[str, str]]
if external_ids is not None:
- new_external_ids = []
- for external_id in external_ids:
- new_external_ids.append(
- (external_id["auth_provider"], external_id["external_id"])
- )
+ new_external_ids = {
+ (external_id["auth_provider"], external_id["external_id"])
+ for external_id in external_ids
+ }
+
+ # convert List[Dict[str, str]] into Set[Tuple[str, str]]
+ if threepids is not None:
+ new_threepids = {
+ (threepid["medium"], threepid["address"]) for threepid in threepids
+ }
if user: # modify user
if "displayname" in body:
@@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
)
if threepids is not None:
- # remove old threepids from user
- old_threepids = await self.store.user_get_threepids(user_id)
- for threepid in old_threepids:
+ # get changed threepids (added and removed)
+ # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
+ cur_threepids = {
+ (threepid["medium"], threepid["address"])
+ for threepid in await self.store.user_get_threepids(user_id)
+ }
+ add_threepids = new_threepids - cur_threepids
+ del_threepids = cur_threepids - new_threepids
+
+ # remove old threepids
+ for medium, address in del_threepids:
try:
await self.auth_handler.delete_threepid(
- user_id, threepid["medium"], threepid["address"], None
+ user_id, medium, address, None
)
except Exception:
logger.exception("Failed to remove threepids")
raise SynapseError(500, "Failed to remove threepids")
- # add new threepids to user
+ # add new threepids
current_time = self.hs.get_clock().time_msec()
- for threepid in threepids:
+ for medium, address in add_threepids:
await self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], current_time
+ user_id, medium, address, current_time
)
if external_ids is not None:
# get changed external_ids (added and removed)
- cur_external_ids = await self.store.get_external_ids_by_user(user_id)
- add_external_ids = set(new_external_ids) - set(cur_external_ids)
- del_external_ids = set(cur_external_ids) - set(new_external_ids)
+ cur_external_ids = set(
+ await self.store.get_external_ids_by_user(user_id)
+ )
+ add_external_ids = new_external_ids - cur_external_ids
+ del_external_ids = cur_external_ids - new_external_ids
# remove old external_ids
for auth_provider, external_id in del_external_ids:
@@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
if threepids is not None:
current_time = self.hs.get_clock().time_msec()
- for threepid in threepids:
+ for medium, address in new_threepids:
await self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], current_time
+ user_id, medium, address, current_time
)
if (
self.hs.config.email_enable_notifs
@@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
- device_display_name=threepid["address"],
- pushkey=threepid["address"],
+ device_display_name=address,
+ pushkey=address,
lang=None, # We don't know a user's language here
data={},
)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c67bea81c6..469dd53e0c 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return user_id
- def get_user_id_by_threepid_txn(self, txn, medium, address):
+ def get_user_id_by_threepid_txn(
+ self, txn, medium: str, address: str
+ ) -> Optional[str]:
"""Returns user id from threepid
Args:
txn (cursor):
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
+ medium: threepid medium e.g. email
+ address: threepid address e.g. me@example.com
Returns:
- str|None: user id or None if no user id/threepid mapping exists
+ user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
txn,
@@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return ret["user_id"]
return None
- async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ async def user_add_threepid(
+ self,
+ user_id: str,
+ medium: str,
+ address: str,
+ validated_at: int,
+ added_at: int,
+ ) -> None:
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id):
+ async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"user_get_threepids",
)
- async def user_delete_threepid(self, user_id, medium, address) -> None:
+ async def user_delete_threepid(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ef77275238..ee204c404b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"]
)
self.assertEqual(
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
)
+ self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
@@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Test setting threepid for an other user.
"""
- # Delete old and add new threepid to user
+ # Add two threepids to user
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob1@bob.bob"},
+ {"medium": "email", "address": "bob2@bob.bob"},
+ ],
+ },
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
+ # result does not always have the same sort order, therefore it becomes sorted
+ sorted_result = sorted(
+ channel.json_body["threepids"], key=lambda k: k["address"]
+ )
+ self.assertEqual("email", sorted_result[0]["medium"])
+ self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
+ self.assertEqual("email", sorted_result[1]["medium"])
+ self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Set a new and remove a threepid
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob2@bob.bob"},
+ {"medium": "email", "address": "bob3@bob.bob"},
+ ],
+ },
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
# Get user
channel = self.make_request(
@@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Remove threepids
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"threepids": []},
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self._check_fields(channel.json_body)
def test_set_external_id(self):
"""
@@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
channel.json_body["external_ids"],
[
--
cgit 1.5.1
From b5fef6054a3d5763a031440ad4c56fc97b12d2aa Mon Sep 17 00:00:00 2001
From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com>
Date: Thu, 19 Aug 2021 11:40:40 +0200
Subject: Support MSC3283: Expose `enable_set_displayname` in capabilities
(#10452)
---
changelog.d/10452.feature | 1 +
synapse/config/experimental.py | 3 +
synapse/rest/client/capabilities.py | 11 +++
tests/rest/client/v2_alpha/test_capabilities.py | 109 +++++++++++++++++++-----
4 files changed, 101 insertions(+), 23 deletions(-)
create mode 100644 changelog.d/10452.feature
(limited to 'synapse')
diff --git a/changelog.d/10452.feature b/changelog.d/10452.feature
new file mode 100644
index 0000000000..f332b383e3
--- /dev/null
+++ b/changelog.d/10452.feature
@@ -0,0 +1 @@
+Add support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): Expose enable_set_displayname in capabilities.
\ No newline at end of file
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b918fb15b0..a85d8a5dc6 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -39,5 +39,8 @@ class ExperimentalConfig(Config):
# MSC3244 (room version capabilities)
self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False)
+ # MSC3283 (set displayname, avatar_url and change 3pid capabilities)
+ self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)
+
# MSC3266 (room summary api)
self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 88e3aac797..093549512e 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -61,6 +61,17 @@ class CapabilitiesRestServlet(RestServlet):
"org.matrix.msc3244.room_capabilities"
] = MSC3244_CAPABILITIES
+ if self.config.experimental.msc3283_enabled:
+ response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
+ "enabled": self.config.enable_set_displayname
+ }
+ response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
+ "enabled": self.config.enable_set_avatar_url
+ }
+ response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
+ "enabled": self.config.enable_3pid_changes
+ }
+
return 200, response
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index ad83b3d2ff..ac31e5ceaf 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -30,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
- self.store = hs.get_datastore()
self.config = hs.config
self.auth_handler = hs.get_auth_handler()
return hs
+ def prepare(self, reactor, clock, hs):
+ self.localpart = "user"
+ self.password = "pass"
+ self.user = self.register_user(self.localpart, self.password)
+
def test_check_auth_required(self):
channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401)
def test_get_room_version_capabilities(self):
- self.register_user("user", "pass")
- access_token = self.login("user", "pass")
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
@@ -57,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
)
def test_get_change_password_capabilities_password_login(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
- access_token = self.login(user, password)
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
@@ -70,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -87,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -102,13 +96,85 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertFalse(capabilities["m.change_password"]["enabled"])
+ def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
+ """Test that per default msc3283 is disabled server returns `m.change_password`."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
+ self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
+ self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
+ self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
+
+ @override_config({"experimental_features": {"msc3283_enabled": True}})
+ def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
+ """Test if msc3283 is enabled server returns capabilities."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+ @override_config(
+ {
+ "enable_set_displayname": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_get_set_displayname_capabilities_displayname_disabled(self):
+ """Test if set displayname is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+
+ @override_config(
+ {
+ "enable_set_avatar_url": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+ """Test if set avatar_url is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+
+ @override_config(
+ {
+ "enable_3pid_changes": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_change_3pid_capabilities_3pid_disabled(self):
+ """Test if change 3pid is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
def test_get_does_not_include_msc3244_fields_by_default(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -122,12 +188,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"experimental_features": {"msc3244_enabled": True}})
def test_get_does_include_msc3244_fields_when_enabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
--
cgit 1.5.1
From 000aa89be63c27092998eca03c97eaead21404cd Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 19 Aug 2021 11:12:55 -0400
Subject: Do not include rooms with an unknown room version in a sync response.
(#10644)
A user will still see this room if it is in a local cache, but it will
not reappear if clearing the cache and reloading.
---
changelog.d/10644.bugfix | 1 +
mypy.ini | 1 +
synapse/handlers/sync.py | 7 +-
synapse/storage/databases/main/roommember.py | 8 +-
synapse/storage/roommember.py | 1 +
tests/handlers/test_sync.py | 137 +++++++++++++++++++++++--
tests/replication/slave/storage/test_events.py | 1 +
7 files changed, 145 insertions(+), 11 deletions(-)
create mode 100644 changelog.d/10644.bugfix
(limited to 'synapse')
diff --git a/changelog.d/10644.bugfix b/changelog.d/10644.bugfix
new file mode 100644
index 0000000000..d88a81fd82
--- /dev/null
+++ b/changelog.d/10644.bugfix
@@ -0,0 +1 @@
+Rooms with unsupported room versions are no longer returned via `/sync`.
diff --git a/mypy.ini b/mypy.ini
index 107f4de76c..90ade37b3f 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -90,6 +90,7 @@ files =
tests/test_utils,
tests/handlers/test_password_providers.py,
tests/handlers/test_room_summary.py,
+ tests/handlers/test_sync.py,
tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_itertools.py,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b7b299961f..2203c45dcc 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1,5 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018, 2019 New Vector Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,6 +30,7 @@ from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
@@ -1843,6 +1843,9 @@ class SyncHandler:
knocked = []
for event in room_list:
+ if event.room_version_id not in KNOWN_ROOM_VERSIONS:
+ continue
+
if event.membership == Membership.JOIN:
room_entries.append(
RoomSyncResultBuilder(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index c2f6b9d63d..c58a4b8690 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -386,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
sql = """
- SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
FROM local_current_membership AS c
INNER JOIN events AS e USING (room_id, event_id)
+ INNER JOIN rooms AS r USING (room_id)
WHERE
user_id = ?
AND %s
@@ -397,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
+ results = [RoomsForUser(*r) for r in txn]
return results
@@ -447,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
Returns the rooms the user is in currently, along with the stream
- ordering of the most recent join for that user and room.
+ ordering of the most recent join for that user and room, along with
+ the room version of the room.
"""
return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9fad67ce48..2500381b7b 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -30,6 +30,7 @@ class RoomsForUser:
membership: str
event_id: str
stream_ordering: int
+ room_version_id: str
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 84f05f6c58..339c039914 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
+from synapse.api.room_versions import RoomVersions
from synapse.handlers.sync import SyncConfig
+from synapse.rest import admin
+from synapse.rest.client import knock, login, room
+from synapse.server import HomeServer
from synapse.types import UserID, create_requester
import tests.unittest
@@ -24,8 +31,14 @@ import tests.utils
class SyncTestCase(tests.unittest.HomeserverTestCase):
"""Tests Sync Handler."""
- def prepare(self, reactor, clock, hs):
- self.hs = hs
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs: HomeServer):
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
@@ -68,12 +81,124 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ def test_unknown_room_version(self):
+ """
+ A room with an unknown room version should not break sync (and should be excluded).
+ """
+ inviter = self.register_user("creator", "pass", admin=True)
+ inviter_tok = self.login("@creator:test", "pass")
+
+ user = self.register_user("user", "pass")
+ tok = self.login("user", "pass")
+
+ # Do an initial sync on a different device.
+ requester = create_requester(user)
+ initial_result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user, device_id="dev")
+ )
+ )
+
+ # Create a room as the user.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ # Invite the user to the room as someone else.
+ invite_room = self.helper.create_room_as(inviter, tok=inviter_tok)
+ self.helper.invite(invite_room, targ=user, tok=inviter_tok)
+
+ knock_room = self.helper.create_room_as(
+ inviter, room_version=RoomVersions.V7.identifier, tok=inviter_tok
+ )
+ self.helper.send_state(
+ knock_room,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=inviter_tok,
+ )
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/knock/%s" % (knock_room,),
+ b"{}",
+ tok,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # The rooms should appear in the sync response.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user)
+ )
+ )
+ self.assertIn(joined_room, [r.room_id for r in result.joined])
+ self.assertIn(invite_room, [r.room_id for r in result.invited])
+ self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+ # Test a incremental sync (by providing a since_token).
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config=generate_sync_config(user, device_id="dev"),
+ since_token=initial_result.next_batch,
+ )
+ )
+ self.assertIn(joined_room, [r.room_id for r in result.joined])
+ self.assertIn(invite_room, [r.room_id for r in result.invited])
+ self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+ # Poke the database and update the room version to an unknown one.
+ for room_id in (joined_room, invite_room, knock_room):
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_update(
+ "rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_version": "unknown-room-version"},
+ desc="updated-room-version",
+ )
+ )
+
+ # Blow away caches (supported room versions can only change due to a restart).
+ self.get_success(
+ self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+ )
+ self.store._get_event_cache.clear()
+
+ # The rooms should be excluded from the sync response.
+ # Get a new request key.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester, sync_config=generate_sync_config(user)
+ )
+ )
+ self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+ self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+ self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+ # The rooms should also not be in an incremental sync.
+ result = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config=generate_sync_config(user, device_id="dev"),
+ since_token=initial_result.next_batch,
+ )
+ )
+ self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+ self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+ self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+
+_request_key = 0
+
-def generate_sync_config(user_id: str) -> SyncConfig:
+def generate_sync_config(
+ user_id: str, device_id: Optional[str] = "device_id"
+) -> SyncConfig:
+ """Generate a sync config (with a unique request key)."""
+ global _request_key
+ _request_key += 1
return SyncConfig(
- user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+ user=UserID.from_string(user_id),
filter_collection=DEFAULT_FILTER_COLLECTION,
is_guest=False,
- request_key="request_key",
- device_id="device_id",
+ request_key=("request_key", _request_key),
+ device_id=device_id,
)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 8d1b0606c4..b25a06b427 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -150,6 +150,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
"invite",
event.event_id,
event.internal_metadata.stream_ordering,
+ RoomVersions.V1.identifier,
)
],
)
--
cgit 1.5.1
From 50af1efe4be8c93ee1fd642b60ab66d32317827b Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Thu, 19 Aug 2021 17:31:40 +0100
Subject: Extract `_resolve_state_at_missing_prevs` (#10624)
This is a follow-up to #10615: it takes the code that constructs the state at a backwards extremity, and extracts it to a separate method.
---
changelog.d/10624.misc | 1 +
synapse/handlers/federation.py | 229 ++++++++++++++++++++++-------------------
2 files changed, 125 insertions(+), 105 deletions(-)
create mode 100644 changelog.d/10624.misc
(limited to 'synapse')
diff --git a/changelog.d/10624.misc b/changelog.d/10624.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10624.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 529d025c39..de86918b7b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -285,12 +285,12 @@ class FederationHandler(BaseHandler):
# - Fetching any missing prev events to fill in gaps in the graph
# - Fetching state if we have a hole in the graph
if not pdu.internal_metadata.is_outlier():
- prevs = set(pdu.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
+ if sent_to_us_directly:
+ prevs = set(pdu.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
- if missing_prevs:
- if sent_to_us_directly:
+ if missing_prevs:
# We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("min_depth: %d", min_depth)
@@ -351,106 +351,8 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
- else:
- # We don't have all of the prev_events for this event.
- #
- # In this case, we need to fall back to asking another server in the
- # federation for the state at this event. That's ok provided we then
- # resolve the state against other bits of the DAG before using it (which
- # will ensure that you can't just take over a room by sending an event,
- # withholding its prev_events, and declaring yourself to be an admin in
- # the subsequent state request).
- #
- # Since we're pulling this event as a missing prev_event, then clearly
- # this event is not going to become the only forward-extremity and we are
- # guaranteed to resolve its state against our existing forward
- # extremities, so that should be fine.
- #
- # XXX this really feels like it could/should be merged with the above,
- # but there is an interaction with min_depth that I'm not really
- # following.
- logger.info(
- "Event %s is missing prev_events %s: calculating state for a "
- "backwards extremity",
- event_id,
- shortstr(missing_prevs),
- )
-
- # Calculate the state after each of the previous events, and
- # resolve them to find the correct state at the current event.
- event_map = {event_id: pdu}
- try:
- # Get the state of the events we know about
- ours = await self.state_store.get_state_groups_ids(
- room_id, seen
- )
-
- # state_maps is a list of mappings from (type, state_key) to event_id
- state_maps: List[StateMap[str]] = list(ours.values())
-
- # we don't need this any more, let's delete it.
- del ours
-
- # Ask the remote server for the states we don't
- # know about
- for p in missing_prevs:
- logger.info(
- "Requesting state after missing prev_event %s", p
- )
-
- with nested_logging_context(p):
- # note that if any of the missing prevs share missing state or
- # auth events, the requests to fetch those events are deduped
- # by the get_pdu_cache in federation_client.
- remote_state = (
- await self._get_state_after_missing_prev_event(
- origin, room_id, p
- )
- )
-
- remote_state_map = {
- (x.type, x.state_key): x.event_id
- for x in remote_state
- }
- state_maps.append(remote_state_map)
-
- for x in remote_state:
- event_map[x.event_id] = x
-
- room_version = await self.store.get_room_version_id(room_id)
- state_map = await self._state_resolution_handler.resolve_events_with_store(
- room_id,
- room_version,
- state_maps,
- event_map,
- state_res_store=StateResolutionStore(self.store),
- )
-
- # We need to give _process_received_pdu the actual state events
- # rather than event ids, so generate that now.
-
- # First though we need to fetch all the events that are in
- # state_map, so we can build up the state below.
- evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
- )
- event_map.update(evs)
-
- state = [event_map[e] for e in state_map.values()]
- except Exception:
- logger.warning(
- "Error attempting to resolve state at missing "
- "prev_events",
- exc_info=True,
- )
- raise FederationError(
- "ERROR",
- 403,
- "We can't get valid state history.",
- affected=event_id,
- )
+ else:
+ state = await self._resolve_state_at_missing_prevs(origin, pdu)
# A second round of checks for all events. Check that the event passes auth
# based on `auth_events`, this allows us to assert that the event would
@@ -1493,6 +1395,123 @@ class FederationHandler(BaseHandler):
event_infos,
)
+ async def _resolve_state_at_missing_prevs(
+ self, dest: str, event: EventBase
+ ) -> Optional[Iterable[EventBase]]:
+ """Calculate the state at an event with missing prev_events.
+
+ This is used when we have pulled a batch of events from a remote server, and
+ still don't have all the prev_events.
+
+ If we already have all the prev_events for `event`, this method does nothing.
+
+ Otherwise, the missing prevs become new backwards extremities, and we fall back
+ to asking the remote server for the state after each missing `prev_event`,
+ and resolving across them.
+
+ That's ok provided we then resolve the state against other bits of the DAG
+ before using it - in other words, that the received event `event` is not going
+ to become the only forwards_extremity in the room (which will ensure that you
+ can't just take over a room by sending an event, withholding its prev_events,
+ and declaring yourself to be an admin in the subsequent state request).
+
+ In other words: we should only call this method if `event` has been *pulled*
+ as part of a batch of missing prev events, or similar.
+
+ Params:
+ dest: the remote server to ask for state at the missing prevs. Typically,
+ this will be the server we got `event` from.
+ event: an event to check for missing prevs.
+
+ Returns:
+ if we already had all the prev events, `None`. Otherwise, returns a list of
+ the events in the state at `event`.
+ """
+ room_id = event.room_id
+ event_id = event.event_id
+
+ prevs = set(event.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ return None
+
+ logger.info(
+ "Event %s is missing prev_events %s: calculating state for a "
+ "backwards extremity",
+ event_id,
+ shortstr(missing_prevs),
+ )
+ # Calculate the state after each of the previous events, and
+ # resolve them to find the correct state at the current event.
+ event_map = {event_id: event}
+ try:
+ # Get the state of the events we know about
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
+
+ # state_maps is a list of mappings from (type, state_key) to event_id
+ state_maps: List[StateMap[str]] = list(ours.values())
+
+ # we don't need this any more, let's delete it.
+ del ours
+
+ # Ask the remote server for the states we don't
+ # know about
+ for p in missing_prevs:
+ logger.info("Requesting state after missing prev_event %s", p)
+
+ with nested_logging_context(p):
+ # note that if any of the missing prevs share missing state or
+ # auth events, the requests to fetch those events are deduped
+ # by the get_pdu_cache in federation_client.
+ remote_state = await self._get_state_after_missing_prev_event(
+ dest, room_id, p
+ )
+
+ remote_state_map = {
+ (x.type, x.state_key): x.event_id for x in remote_state
+ }
+ state_maps.append(remote_state_map)
+
+ for x in remote_state:
+ event_map[x.event_id] = x
+
+ room_version = await self.store.get_room_version_id(room_id)
+ state_map = await self._state_resolution_handler.resolve_events_with_store(
+ room_id,
+ room_version,
+ state_maps,
+ event_map,
+ state_res_store=StateResolutionStore(self.store),
+ )
+
+ # We need to give _process_received_pdu the actual state events
+ # rather than event ids, so generate that now.
+
+ # First though we need to fetch all the events that are in
+ # state_map, so we can build up the state below.
+ evs = await self.store.get_events(
+ list(state_map.values()),
+ get_prev_content=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ event_map.update(evs)
+
+ state = [event_map[e] for e in state_map.values()]
+ except Exception:
+ logger.warning(
+ "Error attempting to resolve state at missing prev_events",
+ exc_info=True,
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ "We can't get valid state history.",
+ affected=event_id,
+ )
+ return state
+
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
--
cgit 1.5.1
From e81d62009ee091d69d56bca702ba0edd39becb1c Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Thu, 19 Aug 2021 18:05:12 +0100
Subject: Split `on_receive_pdu` in half (#10640)
Here we split on_receive_pdu into two functions (on_receive_pdu and process_pulled_event), rather than having both cases in the same method. There's a tiny bit of overlap, but not that much.
---
changelog.d/10640.misc | 1 +
synapse/federation/federation_server.py | 4 +-
synapse/handlers/federation.py | 236 +++++++++++++++++++-------------
tests/test_federation.py | 10 +-
4 files changed, 142 insertions(+), 109 deletions(-)
create mode 100644 changelog.d/10640.misc
(limited to 'synapse')
diff --git a/changelog.d/10640.misc b/changelog.d/10640.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10640.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index afd8f8580a..e1b58d40c5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1005,9 +1005,7 @@ class FederationServer(FederationBase):
async with lock:
logger.info("handling received PDU: %s", event)
try:
- await self.handler.on_receive_pdu(
- origin, event, sent_to_us_directly=True
- )
+ await self.handler.on_receive_pdu(origin, event)
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index de86918b7b..246df43501 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -203,18 +203,13 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- async def on_receive_pdu(
- self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
- ) -> None:
- """Process a PDU received via a federation /send/ transaction, or
- via backfill of missing prev_events
+ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+ """Process a PDU received via a federation /send/ transaction
Args:
origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state.
pdu: received PDU
- sent_to_us_directly: True if this event was pushed to us; False if
- we pulled it as the result of a missing prev_event.
"""
room_id = pdu.room_id
@@ -276,8 +271,6 @@ class FederationHandler(BaseHandler):
)
return None
- state = None
-
# Check that the event passes auth based on the state at the event. This is
# done for events that are to be added to the timeline (non-outliers).
#
@@ -285,83 +278,72 @@ class FederationHandler(BaseHandler):
# - Fetching any missing prev events to fill in gaps in the graph
# - Fetching state if we have a hole in the graph
if not pdu.internal_metadata.is_outlier():
- if sent_to_us_directly:
- prevs = set(pdu.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if missing_prevs:
- # We only backfill backwards to the min depth.
- min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug("min_depth: %d", min_depth)
-
- if min_depth is not None and pdu.depth > min_depth:
- # If we're missing stuff, ensure we only fetch stuff one
- # at a time.
+ prevs = set(pdu.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if missing_prevs:
+ # We only backfill backwards to the min depth.
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
+ logger.debug("min_depth: %d", min_depth)
+
+ if min_depth is not None and pdu.depth > min_depth:
+ # If we're missing stuff, ensure we only fetch stuff one
+ # at a time.
+ logger.info(
+ "Acquiring room lock to fetch %d missing prev_events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
- "Acquiring room lock to fetch %d missing prev_events: %s",
+ "Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
- shortstr(missing_prevs),
)
- with (await self._room_pdu_linearizer.queue(pdu.room_id)):
- logger.info(
- "Acquired room lock to fetch %d missing prev_events",
- len(missing_prevs),
+
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
)
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ ) from e
- try:
- await self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
- except Exception as e:
- raise Exception(
- "Error fetching missing prev_events for %s: %s"
- % (event_id, e)
- ) from e
-
- # Update the set of things we've seen after trying to
- # fetch the missing stuff
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if not missing_prevs:
- logger.info("Found all missing prev_events")
-
- if missing_prevs:
- # since this event was pushed to us, it is possible for it to
- # become the only forward-extremity in the room, and we would then
- # trust its state to be the state for the whole room. This is very
- # bad. Further, if the event was pushed to us, there is no excuse
- # for us not to have all the prev_events. (XXX: apart from
- # min_depth?)
- #
- # We therefore reject any such events.
- logger.warning(
- "Rejecting: failed to fetch %d prev events: %s",
- len(missing_prevs),
- shortstr(missing_prevs),
- )
- raise FederationError(
- "ERROR",
- 403,
- (
- "Your server isn't divulging details about prev_events "
- "referenced in this event."
- ),
- affected=pdu.event_id,
- )
+ # Update the set of things we've seen after trying to
+ # fetch the missing stuff
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
- else:
- state = await self._resolve_state_at_missing_prevs(origin, pdu)
+ if not missing_prevs:
+ logger.info("Found all missing prev_events")
+
+ if missing_prevs:
+ # since this event was pushed to us, it is possible for it to
+ # become the only forward-extremity in the room, and we would then
+ # trust its state to be the state for the whole room. This is very
+ # bad. Further, if the event was pushed to us, there is no excuse
+ # for us not to have all the prev_events. (XXX: apart from
+ # min_depth?)
+ #
+ # We therefore reject any such events.
+ logger.warning(
+ "Rejecting: failed to fetch %d prev events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
+ )
- # A second round of checks for all events. Check that the event passes auth
- # based on `auth_events`, this allows us to assert that the event would
- # have been allowed at some point. If an event passes this check its OK
- # for it to be used as part of a returned `/state` request, as either
- # a) we received the event as part of the original join and so trust it, or
- # b) we'll do a state resolution with existing state before it becomes
- # part of the "current state", which adds more protection.
- await self._process_received_pdu(origin, pdu, state=state)
+ await self._process_received_pdu(origin, pdu, state=None)
async def _get_missing_events_for_pdu(
self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
@@ -461,24 +443,7 @@ class FederationHandler(BaseHandler):
return
logger.info("Got %d prev_events", len(missing_events))
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- missing_events.sort(key=lambda x: x.depth)
-
- for ev in missing_events:
- logger.info("Handling received prev_event %s", ev)
- with nested_logging_context(ev.event_id):
- try:
- await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
- except FederationError as e:
- if e.code == 403:
- logger.warning(
- "Received prev_event %s failed history check.",
- ev.event_id,
- )
- else:
- raise
+ await self._process_pulled_events(origin, missing_events)
async def _get_state_for_room(
self,
@@ -1395,6 +1360,81 @@ class FederationHandler(BaseHandler):
event_infos,
)
+ async def _process_pulled_events(
+ self, origin: str, events: Iterable[EventBase]
+ ) -> None:
+ """Process a batch of events we have pulled from a remote server
+
+ Pulls in any events required to auth the events, persists the received events,
+ and notifies clients, if appropriate.
+
+ Assumes the events have already had their signatures and hashes checked.
+
+ Params:
+ origin: The server we received these events from
+ events: The received events.
+ """
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ sorted_events = sorted(events, key=lambda x: x.depth)
+
+ for ev in sorted_events:
+ with nested_logging_context(ev.event_id):
+ await self._process_pulled_event(origin, ev)
+
+ async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
+ """Process a single event that we have pulled from a remote server
+
+ Pulls in any events required to auth the event, persists the received event,
+ and notifies clients, if appropriate.
+
+ Assumes the event has already had its signatures and hashes checked.
+
+ This is somewhat equivalent to on_receive_pdu, but applies somewhat different
+ logic in the case that we are missing prev_events (in particular, it just
+ requests the state at that point, rather than triggering a get_missing_events) -
+ so is appropriate when we have pulled the event from a remote server, rather
+ than having it pushed to us.
+
+ Params:
+ origin: The server we received this event from
+ events: The received event
+ """
+ logger.info("Processing pulled event %s", event)
+
+ # these should not be outliers.
+ assert not event.internal_metadata.is_outlier()
+
+ event_id = event.event_id
+
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
+ )
+ if existing:
+ if not existing.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received event %s which we have already seen",
+ event_id,
+ )
+ return
+ logger.info("De-outliering event %s", event_id)
+
+ try:
+ self._sanity_check_event(event)
+ except SynapseError as err:
+ logger.warning("Event %s failed sanity check: %s", event_id, err)
+ return
+
+ try:
+ state = await self._resolve_state_at_missing_prevs(origin, event)
+ await self._process_received_pdu(origin, event, state=state)
+ except FederationError as e:
+ if e.code == 403:
+ logger.warning("Pulled event %s failed history check.", event_id)
+ else:
+ raise
+
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]:
@@ -1780,7 +1820,7 @@ class FederationHandler(BaseHandler):
p,
)
with nested_logging_context(p.event_id):
- await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3785799f46..348fcb72a7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Send the join, it should return None (which is not an error)
self.assertEqual(
- self.get_success(
- self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
- )
- ),
+ self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
None,
)
@@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
with LoggingContext("test-context"):
failure = self.get_failure(
- self.handler.on_receive_pdu(
- "test.serv", lying_event, sent_to_us_directly=True
- ),
+ self.handler.on_receive_pdu("test.serv", lying_event),
FederationError,
)
--
cgit 1.5.1
From ee3b2ac59a5645cb213b3c11c473613b80fe1e0f Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Fri, 20 Aug 2021 15:47:03 +0100
Subject: Validate device_keys for C-S /keys/query requests (#10593)
* Validate device_keys for C-S /keys/query requests
Closes #10354
A small, not particularly critical fix. I'm interested in seeing if we
can find a more systematic approach though. #8445 is the place for any discussion.
---
changelog.d/10593.bugfix | 1 +
synapse/api/errors.py | 8 ++++
synapse/rest/client/keys.py | 16 ++++++-
tests/rest/client/v2_alpha/test_keys.py | 77 +++++++++++++++++++++++++++++++++
4 files changed, 101 insertions(+), 1 deletion(-)
create mode 100644 changelog.d/10593.bugfix
create mode 100644 tests/rest/client/v2_alpha/test_keys.py
(limited to 'synapse')
diff --git a/changelog.d/10593.bugfix b/changelog.d/10593.bugfix
new file mode 100644
index 0000000000..492e58a7a8
--- /dev/null
+++ b/changelog.d/10593.bugfix
@@ -0,0 +1 @@
+Reject Client-Server /keys/query requests which provide device_ids incorrectly.
\ No newline at end of file
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index dc662bca83..9480f448d7 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -147,6 +147,14 @@ class SynapseError(CodeMessageException):
return cs_error(self.msg, self.errcode)
+class InvalidAPICallError(SynapseError):
+ """You called an existing API endpoint, but fed that endpoint
+ invalid or incomplete data."""
+
+ def __init__(self, msg: str):
+ super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON)
+
+
class ProxiedRequestError(SynapseError):
"""An error from a general matrix endpoint, eg. from a proxied Matrix API call.
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index d0d9d30d40..012491f597 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,8 +15,9 @@
# limitations under the License.
import logging
+from typing import Any
-from synapse.api.errors import SynapseError
+from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_integer,
@@ -163,6 +164,19 @@ class KeyQueryServlet(RestServlet):
device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
+
+ device_keys = body.get("device_keys")
+ if not isinstance(device_keys, dict):
+ raise InvalidAPICallError("'device_keys' must be a JSON object")
+
+ def is_list_of_strings(values: Any) -> bool:
+ return isinstance(values, list) and all(isinstance(v, str) for v in values)
+
+ if any(not is_list_of_strings(keys) for keys in device_keys.values()):
+ raise InvalidAPICallError(
+ "'device_keys' values must be a list of strings",
+ )
+
result = await self.e2e_keys_handler.query_devices(
body, timeout, user_id, device_id
)
diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py
new file mode 100644
index 0000000000..80a4e728ff
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_keys.py
@@ -0,0 +1,77 @@
+from http import HTTPStatus
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import keys, login
+
+from tests import unittest
+
+
+class KeyQueryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ keys.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def test_rejects_device_id_ice_key_outside_of_list(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: "device_id1",
+ },
+ },
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_rejects_device_key_given_as_map_to_bool(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: {
+ "device_id1": True,
+ },
+ },
+ },
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_requires_device_key(self):
+ """`device_keys` is required. We should complain if it's missing."""
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {},
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
--
cgit 1.5.1
From 947dbbdfd1e0029da66f956d277b7c089928e1e7 Mon Sep 17 00:00:00 2001
From: Callum Brown
Date: Sat, 21 Aug 2021 22:14:43 +0100
Subject: Implement MSC3231: Token authenticated registration (#10142)
Signed-off-by: Callum Brown
This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
---
changelog.d/10142.feature | 1 +
docs/SUMMARY.md | 1 +
docs/sample_config.yaml | 15 +
.../admin_api/registration_tokens.md | 295 +++++++++
docs/workers.md | 1 +
synapse/api/constants.py | 1 +
synapse/app/generic_worker.py | 6 +-
synapse/config/ratelimiting.py | 11 +
synapse/config/registration.py | 15 +
synapse/handlers/ui_auth/__init__.py | 5 +
synapse/handlers/ui_auth/checkers.py | 65 ++
synapse/res/templates/registration_token.html | 23 +
synapse/rest/admin/__init__.py | 8 +
synapse/rest/admin/registration_tokens.py | 321 ++++++++++
synapse/rest/client/auth.py | 24 +
synapse/rest/client/register.py | 72 +++
synapse/storage/databases/main/registration.py | 316 +++++++++
synapse/storage/databases/main/ui_auth.py | 43 ++
.../main/delta/63/01create_registration_tokens.sql | 23 +
tests/rest/admin/test_registration_tokens.py | 710 +++++++++++++++++++++
tests/rest/client/test_register.py | 434 +++++++++++++
21 files changed, 2389 insertions(+), 1 deletion(-)
create mode 100644 changelog.d/10142.feature
create mode 100644 docs/usage/administration/admin_api/registration_tokens.md
create mode 100644 synapse/res/templates/registration_token.html
create mode 100644 synapse/rest/admin/registration_tokens.py
create mode 100644 synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
create mode 100644 tests/rest/admin/test_registration_tokens.py
(limited to 'synapse')
diff --git a/changelog.d/10142.feature b/changelog.d/10142.feature
new file mode 100644
index 0000000000..5353f6269d
--- /dev/null
+++ b/changelog.d/10142.feature
@@ -0,0 +1 @@
+Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown.
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index 634bc833ab..4fcd2b7852 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -53,6 +53,7 @@
- [Media](admin_api/media_admin_api.md)
- [Purge History](admin_api/purge_history_api.md)
- [Register Users](admin_api/register_api.md)
+ - [Registration Tokens](usage/administration/admin_api/registration_tokens.md)
- [Manipulate Room Membership](admin_api/room_membership.md)
- [Rooms](admin_api/rooms.md)
- [Server Notices](admin_api/server_notices.md)
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 2b0c453242..935841dbfa 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
+# - one for checking the validity of registration tokens that ratelimits
+# requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP
# address.
# - one for login that ratelimits login requests based on the account the
@@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# per_second: 0.17
# burst_count: 3
#
+#rc_registration_token_validity:
+# per_second: 0.1
+# burst_count: 5
+#
#rc_login:
# address:
# per_second: 0.17
@@ -1169,6 +1175,15 @@ url_preview_accept_language:
#
#enable_3pid_lookup: true
+# Require users to submit a token during registration.
+# Tokens can be managed using the admin API:
+# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
+# Note that `enable_registration` must be set to `true`.
+# Disabling this option will not delete any tokens previously generated.
+# Defaults to false. Uncomment the following to require tokens:
+#
+#registration_requires_token: true
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md
new file mode 100644
index 0000000000..828c0277d6
--- /dev/null
+++ b/docs/usage/administration/admin_api/registration_tokens.md
@@ -0,0 +1,295 @@
+# Registration Tokens
+
+This API allows you to manage tokens which can be used to authenticate
+registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md).
+To use it, you will need to enable the `registration_requires_token` config
+option, and authenticate by providing an `access_token` for a server admin:
+see [Admin API](../../usage/administration/admin_api).
+Note that this API is still experimental; not all clients may support it yet.
+
+
+## Registration token objects
+
+Most endpoints make use of JSON objects that contain details about tokens.
+These objects have the following fields:
+- `token`: The token which can be used to authenticate registration.
+- `uses_allowed`: The number of times the token can be used to complete a
+ registration before it becomes invalid.
+- `pending`: The number of pending uses the token has. When someone uses
+ the token to authenticate themselves, the pending counter is incremented
+ so that the token is not used more than the permitted number of times.
+ When the person completes registration the pending counter is decremented,
+ and the completed counter is incremented.
+- `completed`: The number of times the token has been used to successfully
+ complete a registration.
+- `expiry_time`: The latest time the token is valid. Given as the number of
+ milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+ To convert this into a human-readable form you can remove the milliseconds
+ and use the `date` command. For example, `date -d '@1625394937'`.
+
+
+## List all tokens
+
+Lists all tokens and details about them. If the request is successful, the top
+level JSON object will have a `registration_tokens` key which is an array of
+registration token objects.
+
+```
+GET /_synapse/admin/v1/registration_tokens
+```
+
+Optional query parameters:
+- `valid`: `true` or `false`. If `true`, only valid tokens are returned.
+ If `false`, only tokens that have expired or have had all uses exhausted are
+ returned. If omitted, all tokens are returned regardless of validity.
+
+Example:
+
+```
+GET /_synapse/admin/v1/registration_tokens
+```
+```
+200 OK
+
+{
+ "registration_tokens": [
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "pqrs",
+ "uses_allowed": 2,
+ "pending": 1,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "wxyz",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 9,
+ "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
+ }
+ ]
+}
+```
+
+Example using the `valid` query parameter:
+
+```
+GET /_synapse/admin/v1/registration_tokens?valid=false
+```
+```
+200 OK
+
+{
+ "registration_tokens": [
+ {
+ "token": "pqrs",
+ "uses_allowed": 2,
+ "pending": 1,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "wxyz",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 9,
+ "expiry_time": 1625394937000 // 2021-07-04 10:35:37 UTC
+ }
+ ]
+}
+```
+
+
+## Get one token
+
+Get details about a single token. If the request is successful, the response
+body will be a registration token object.
+
+```
+GET /_synapse/admin/v1/registration_tokens/
+```
+
+Path parameters:
+- `token`: The registration token to return details of.
+
+Example:
+
+```
+GET /_synapse/admin/v1/registration_tokens/abcd
+```
+```
+200 OK
+
+{
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+}
+```
+
+
+## Create token
+
+Create a new registration token. If the request is successful, the newly created
+token will be returned as a registration token object in the response body.
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+```
+
+The request body must be a JSON object and can contain the following fields:
+- `token`: The registration token. A string of no more than 64 characters that
+ consists only of characters matched by the regex `[A-Za-z0-9-_]`.
+ Default: randomly generated.
+- `uses_allowed`: The integer number of times the token can be used to complete
+ a registration before it becomes invalid.
+ Default: `null` (unlimited uses).
+- `expiry_time`: The latest time the token is valid. Given as the number of
+ milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+ You could use, for example, `date '+%s000' -d 'tomorrow'`.
+ Default: `null` (token does not expire).
+- `length`: The length of the token randomly generated if `token` is not
+ specified. Must be between 1 and 64 inclusive. Default: `16`.
+
+If a field is omitted the default is used.
+
+Example using defaults:
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+
+{}
+```
+```
+200 OK
+
+{
+ "token": "0M-9jbkf2t_Tgiw1",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": null
+}
+```
+
+Example specifying some fields:
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+
+{
+ "token": "defg",
+ "uses_allowed": 1
+}
+```
+```
+200 OK
+
+{
+ "token": "defg",
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": null
+}
+```
+
+
+## Update token
+
+Update the number of allowed uses or expiry time of a token. If the request is
+successful, the updated token will be returned as a registration token object
+in the response body.
+
+```
+PUT /_synapse/admin/v1/registration_tokens/
+```
+
+Path parameters:
+- `token`: The registration token to update.
+
+The request body must be a JSON object and can contain the following fields:
+- `uses_allowed`: The integer number of times the token can be used to complete
+ a registration before it becomes invalid. By setting `uses_allowed` to `0`
+ the token can be easily made invalid without deleting it.
+ If `null` the token will have an unlimited number of uses.
+- `expiry_time`: The latest time the token is valid. Given as the number of
+ milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+ If `null` the token will not expire.
+
+If a field is omitted its value is not modified.
+
+Example:
+
+```
+PUT /_synapse/admin/v1/registration_tokens/defg
+
+{
+ "expiry_time": 4781243146000 // 2121-07-06 11:05:46 UTC
+}
+```
+```
+200 OK
+
+{
+ "token": "defg",
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": 4781243146000
+}
+```
+
+
+## Delete token
+
+Delete a registration token. If the request is successful, the response body
+will be an empty JSON object.
+
+```
+DELETE /_synapse/admin/v1/registration_tokens/
+```
+
+Path parameters:
+- `token`: The registration token to delete.
+
+Example:
+
+```
+DELETE /_synapse/admin/v1/registration_tokens/wxyz
+```
+```
+200 OK
+
+{}
+```
+
+
+## Errors
+
+If a request fails a "standard error response" will be returned as defined in
+the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards).
+
+For example, if the token specified in a path parameter does not exist a
+`404 Not Found` error will be returned.
+
+```
+GET /_synapse/admin/v1/registration_tokens/1234
+```
+```
+404 Not Found
+
+{
+ "errcode": "M_NOT_FOUND",
+ "error": "No such registration token: 1234"
+}
+```
diff --git a/docs/workers.md b/docs/workers.md
index 2e63f03452..3121241894 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -236,6 +236,7 @@ expressions:
# Registration/login requests
^/_matrix/client/(api/v1|r0|unstable)/login$
^/_matrix/client/(r0|unstable)/register$
+ ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$
# Event sending requests
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e0e24fddac..829061c870 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -79,6 +79,7 @@ class LoginType:
TERMS = "m.login.terms"
SSO = "m.login.sso"
DUMMY = "m.login.dummy"
+ REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
# This is used in the `type` parameter for /register when called by
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 845e6a8220..fd2626dbe1 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
ProfileRestServlet,
)
from synapse.rest.client.push_rule import PushRuleRestServlet
-from synapse.rest.client.register import RegisterRestServlet
+from synapse.rest.client.register import (
+ RegisterRestServlet,
+ RegistrationTokenValidityRestServlet,
+)
from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.rest.client.voip import VoipRestServlet
@@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer):
resource = JsonResource(self, canonical_json=False)
RegisterRestServlet(self).register(resource)
+ RegistrationTokenValidityRestServlet(self).register(resource)
login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 7a8d5851c4..f856327bd8 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -79,6 +79,11 @@ class RatelimitConfig(Config):
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_registration_token_validity = RateLimitConfig(
+ config.get("rc_registration_token_validity", {}),
+ defaults={"per_second": 0.1, "burst_count": 5},
+ )
+
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
@@ -143,6 +148,8 @@ class RatelimitConfig(Config):
# is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
+ # - one for checking the validity of registration tokens that ratelimits
+ # requests based on the client's IP address.
# - one for login that ratelimits login requests based on the client's IP
# address.
# - one for login that ratelimits login requests based on the account the
@@ -171,6 +178,10 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
#
+ #rc_registration_token_validity:
+ # per_second: 0.1
+ # burst_count: 5
+ #
#rc_login:
# address:
# per_second: 0.17
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 0ad919b139..7cffdacfa5 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -33,6 +33,9 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
+ self.registration_requires_token = config.get(
+ "registration_requires_token", False
+ )
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -140,6 +143,9 @@ class RegistrationConfig(Config):
"mechanism by removing the `access_token_lifetime` option."
)
+ # The fallback template used for authenticating using a registration token
+ self.registration_token_template = self.read_template("registration_token.html")
+
# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")
@@ -199,6 +205,15 @@ class RegistrationConfig(Config):
#
#enable_3pid_lookup: true
+ # Require users to submit a token during registration.
+ # Tokens can be managed using the admin API:
+ # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
+ # Note that `enable_registration` must be set to `true`.
+ # Disabling this option will not delete any tokens previously generated.
+ # Defaults to false. Uncomment the following to require tokens:
+ #
+ #registration_requires_token: true
+
# If set, allows registration of standard or admin accounts by anyone who
# has the shared secret, even if registration is otherwise disabled.
#
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 4c3b669fae..13b0c61d2e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
# used by validate_user_via_ui_auth to store the mxid of the user we are validating
# for.
REQUEST_USER_ID = "request_user_id"
+
+ # used during registration to store the registration token used (if required) so that:
+ # - we can prevent a token being used twice by one session
+ # - we can 'use up' the token after registration has successfully completed
+ REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 270541cc76..d3828dec6b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return await self._check_threepid("msisdn", authdict)
+class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.REGISTRATION_TOKEN
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self.hs = hs
+ self._enabled = bool(hs.config.registration_requires_token)
+ self.store = hs.get_datastore()
+
+ def is_enabled(self) -> bool:
+ return self._enabled
+
+ async def check_auth(self, authdict: dict, clientip: str) -> Any:
+ if "token" not in authdict:
+ raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
+ if not isinstance(authdict["token"], str):
+ raise LoginError(
+ 400, "Registration token must be a string", Codes.INVALID_PARAM
+ )
+ if "session" not in authdict:
+ raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
+
+ # Get these here to avoid cyclic dependencies
+ from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+
+ auth_handler = self.hs.get_auth_handler()
+
+ session = authdict["session"]
+ token = authdict["token"]
+
+ # If the LoginType.REGISTRATION_TOKEN stage has already been completed,
+ # return early to avoid incrementing `pending` again.
+ stored_token = await auth_handler.get_session_data(
+ session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
+ )
+ if stored_token:
+ if token != stored_token:
+ raise LoginError(
+ 400, "Registration token has changed", Codes.INVALID_PARAM
+ )
+ else:
+ return token
+
+ if await self.store.registration_token_is_valid(token):
+ # Increment pending counter, so that if token has limited uses it
+ # can't be used up by someone else in the meantime.
+ await self.store.set_registration_token_pending(token)
+ # Store the token in the UIA session, so that once registration
+ # is complete `completed` can be incremented.
+ await auth_handler.set_session_data(
+ session,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ token,
+ )
+ # The token will be stored as the result of the authentication stage
+ # in ui_auth_sessions_credentials. This allows the pending counter
+ # for tokens to be decremented when expired sessions are deleted.
+ return token
+ else:
+ raise LoginError(
+ 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
+ )
+
+
INTERACTIVE_AUTH_CHECKERS = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
+ RegistrationTokenAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
new file mode 100644
index 0000000000..4577ce1702
--- /dev/null
+++ b/synapse/res/templates/registration_token.html
@@ -0,0 +1,23 @@
+
+
+Authentication
+
+
+
+
+
+
+
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 7f3051aef1..6e1c8736e1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import (
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
+from synapse.rest.admin.registration_tokens import (
+ ListRegistrationTokensRestServlet,
+ NewRegistrationTokenRestServlet,
+ RegistrationTokenRestServlet,
+)
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,
@@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomEventContextServlet(hs).register(http_server)
RateLimitRestServlet(hs).register(http_server)
UsernameAvailableRestServlet(hs).register(http_server)
+ ListRegistrationTokensRestServlet(hs).register(http_server)
+ NewRegistrationTokenRestServlet(hs).register(http_server)
+ RegistrationTokenRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
new file mode 100644
index 0000000000..5a1c929d85
--- /dev/null
+++ b/synapse/rest/admin/registration_tokens.py
@@ -0,0 +1,321 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import string
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_boolean,
+ parse_json_object_from_request,
+)
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListRegistrationTokensRestServlet(RestServlet):
+ """List registration tokens.
+
+ To list all tokens:
+
+ GET /_synapse/admin/v1/registration_tokens
+
+ 200 OK
+
+ {
+ "registration_tokens": [
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ },
+ {
+ "token": "wxyz",
+ "uses_allowed": null,
+ "pending": 0,
+ "completed": 9,
+ "expiry_time": 1625394937000
+ }
+ ]
+ }
+
+ The optional query parameter `valid` can be used to filter the response.
+ If it is `true`, only valid tokens are returned. If it is `false`, only
+ tokens that have expired or have had all uses exhausted are returned.
+ If it is omitted, all tokens are returned regardless of validity.
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+ valid = parse_boolean(request, "valid")
+ token_list = await self.store.get_registration_tokens(valid)
+ return 200, {"registration_tokens": token_list}
+
+
+class NewRegistrationTokenRestServlet(RestServlet):
+ """Create a new registration token.
+
+ For example, to create a token specifying some fields:
+
+ POST /_synapse/admin/v1/registration_tokens/new
+
+ {
+ "token": "defg",
+ "uses_allowed": 1
+ }
+
+ 200 OK
+
+ {
+ "token": "defg",
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": null
+ }
+
+ Defaults are used for any fields not specified.
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens/new$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ # A string of all the characters allowed to be in a registration_token
+ self.allowed_chars = string.ascii_letters + string.digits + "-_"
+ self.allowed_chars_set = set(self.allowed_chars)
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+
+ if "token" in body:
+ token = body["token"]
+ if not isinstance(token, str):
+ raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+ if not (0 < len(token) <= 64):
+ raise SynapseError(
+ 400,
+ "token must not be empty and must not be longer than 64 characters",
+ Codes.INVALID_PARAM,
+ )
+ if not set(token).issubset(self.allowed_chars_set):
+ raise SynapseError(
+ 400,
+ "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
+ Codes.INVALID_PARAM,
+ )
+
+ else:
+ # Get length of token to generate (default is 16)
+ length = body.get("length", 16)
+ if not isinstance(length, int):
+ raise SynapseError(
+ 400, "length must be an integer", Codes.INVALID_PARAM
+ )
+ if not (0 < length <= 64):
+ raise SynapseError(
+ 400,
+ "length must be greater than zero and not greater than 64",
+ Codes.INVALID_PARAM,
+ )
+
+ # Generate token
+ token = await self.store.generate_registration_token(
+ length, self.allowed_chars
+ )
+
+ uses_allowed = body.get("uses_allowed", None)
+ if not (
+ uses_allowed is None
+ or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ ):
+ raise SynapseError(
+ 400,
+ "uses_allowed must be a non-negative integer or null",
+ Codes.INVALID_PARAM,
+ )
+
+ expiry_time = body.get("expiry_time", None)
+ if not isinstance(expiry_time, (int, type(None))):
+ raise SynapseError(
+ 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ )
+ if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ raise SynapseError(
+ 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ )
+
+ created = await self.store.create_registration_token(
+ token, uses_allowed, expiry_time
+ )
+ if not created:
+ raise SynapseError(
+ 400, f"Token already exists: {token}", Codes.INVALID_PARAM
+ )
+
+ resp = {
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ }
+ return 200, resp
+
+
+class RegistrationTokenRestServlet(RestServlet):
+ """Retrieve, update, or delete the given token.
+
+ For example,
+
+ to retrieve a token:
+
+ GET /_synapse/admin/v1/registration_tokens/abcd
+
+ 200 OK
+
+ {
+ "token": "abcd",
+ "uses_allowed": 3,
+ "pending": 0,
+ "completed": 1,
+ "expiry_time": null
+ }
+
+
+ to update a token:
+
+ PUT /_synapse/admin/v1/registration_tokens/defg
+
+ {
+ "uses_allowed": 5,
+ "expiry_time": 4781243146000
+ }
+
+ 200 OK
+
+ {
+ "token": "defg",
+ "uses_allowed": 5,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": 4781243146000
+ }
+
+
+ to delete a token:
+
+ DELETE /_synapse/admin/v1/registration_tokens/wxyz
+
+ 200 OK
+
+ {}
+ """
+
+ PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+ """Retrieve a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+ token_info = await self.store.get_one_registration_token(token)
+
+ # If no result return a 404
+ if token_info is None:
+ raise NotFoundError(f"No such registration token: {token}")
+
+ return 200, token_info
+
+ async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+ """Update a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+ new_attributes = {}
+
+ # Only add uses_allowed to new_attributes if it is present and valid
+ if "uses_allowed" in body:
+ uses_allowed = body["uses_allowed"]
+ if not (
+ uses_allowed is None
+ or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+ ):
+ raise SynapseError(
+ 400,
+ "uses_allowed must be a non-negative integer or null",
+ Codes.INVALID_PARAM,
+ )
+ new_attributes["uses_allowed"] = uses_allowed
+
+ if "expiry_time" in body:
+ expiry_time = body["expiry_time"]
+ if not isinstance(expiry_time, (int, type(None))):
+ raise SynapseError(
+ 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+ )
+ if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+ raise SynapseError(
+ 400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+ )
+ new_attributes["expiry_time"] = expiry_time
+
+ if len(new_attributes) == 0:
+ # Nothing to update, get token info to return
+ token_info = await self.store.get_one_registration_token(token)
+ else:
+ token_info = await self.store.update_registration_token(
+ token, new_attributes
+ )
+
+ # If no result return a 404
+ if token_info is None:
+ raise NotFoundError(f"No such registration token: {token}")
+
+ return 200, token_info
+
+ async def on_DELETE(
+ self, request: SynapseRequest, token: str
+ ) -> Tuple[int, JsonDict]:
+ """Delete a registration token."""
+ await assert_requester_is_admin(self.auth, request)
+
+ if await self.store.delete_registration_token(token):
+ return 200, {}
+
+ raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 73284e48ec..91800c0278 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
+ self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
async def on_GET(self, request, stagetype):
@@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
# re-authenticate with their SSO provider.
html = await self.auth_handler.start_sso_ui_auth(request, session)
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ )
+
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet):
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
+ elif stagetype == LoginType.REGISTRATION_TOKEN:
+ token = parse_string(request, "token", required=True)
+ authdict = {"session": session, "token": token}
+
+ try:
+ await self.auth_handler.add_oob_auth(
+ LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+ )
+ except LoginError as e:
+ html = self.registration_token_template.render(
+ session=session,
+ myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+ error=e.msg,
+ )
+ else:
+ html = self.success_template.render()
+
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 58b8e8f261..2781a0ea96 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
@@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
return 200, {"available": True}
+class RegistrationTokenValidityRestServlet(RestServlet):
+ """Check the validity of a registration token.
+
+ Example:
+
+ GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+ 200 OK
+
+ {
+ "valid": true
+ }
+ """
+
+ PATTERNS = client_patterns(
+ f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+ releases=(),
+ unstable=True,
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.ratelimiter = Ratelimiter(
+ store=self.store,
+ clock=hs.get_clock(),
+ rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+ burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+ )
+
+ async def on_GET(self, request):
+ await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+ if not self.hs.config.enable_registration:
+ raise SynapseError(
+ 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+ )
+
+ token = parse_string(request, "token", required=True)
+ valid = await self.store.registration_token_is_valid(token)
+
+ return 200, {"valid": valid}
+
+
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
@@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
)
if registered:
+ # Check if a token was used to authenticate registration
+ registration_token = await self.auth_handler.get_session_data(
+ session_id,
+ UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+ )
+ if registration_token:
+ # Increment the `completed` counter for the token
+ await self.store.use_registration_token(registration_token)
+ # Indicate that the token has been successfully used so that
+ # pending is not decremented again when expiring old UIA sessions.
+ await self.store.mark_ui_auth_stage_complete(
+ session_id,
+ LoginType.REGISTRATION_TOKEN,
+ True,
+ )
+
await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
@@ -868,6 +934,11 @@ def _calculate_registration_flows(
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
+ # Prepend registration token to all flows if we're requiring a token
+ if config.registration_requires_token:
+ for flow in flows:
+ flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
return flows
@@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server)
+ RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 469dd53e0c..a6517962f6 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated",
)
+ async def registration_token_is_valid(self, token: str) -> bool:
+ """Checks if a token can be used to authenticate a registration.
+
+ Args:
+ token: The registration token to be checked
+ Returns:
+ True if the token is valid, False otherwise.
+ """
+ res = await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ )
+
+ # Check if the token exists
+ if res is None:
+ return False
+
+ # Check if the token has expired
+ now = self._clock.time_msec()
+ if res["expiry_time"] and res["expiry_time"] < now:
+ return False
+
+ # Check if the token has been used up
+ if (
+ res["uses_allowed"]
+ and res["pending"] + res["completed"] >= res["uses_allowed"]
+ ):
+ return False
+
+ # Otherwise, the token is valid
+ return True
+
+ async def set_registration_token_pending(self, token: str) -> None:
+ """Increment the pending registrations counter for a token.
+
+ Args:
+ token: The registration token pending use
+ """
+
+ def _set_registration_token_pending_txn(txn):
+ pending = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": pending + 1},
+ )
+
+ return await self.db_pool.runInteraction(
+ "set_registration_token_pending", _set_registration_token_pending_txn
+ )
+
+ async def use_registration_token(self, token: str) -> None:
+ """Complete a use of the given registration token.
+
+ The `pending` counter will be decremented, and the `completed`
+ counter will be incremented.
+
+ Args:
+ token: The registration token to be 'used'
+ """
+
+ def _use_registration_token_txn(txn):
+ # Normally, res is Optional[Dict[str, Any]].
+ # Override type because the return type is only optional if
+ # allow_none is True, and we don't want mypy throwing errors
+ # about None not being indexable.
+ res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ ) # type: ignore
+
+ # Decrement pending and increment completed
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={
+ "completed": res["completed"] + 1,
+ "pending": res["pending"] - 1,
+ },
+ )
+
+ return await self.db_pool.runInteraction(
+ "use_registration_token", _use_registration_token_txn
+ )
+
+ async def get_registration_tokens(
+ self, valid: Optional[bool] = None
+ ) -> List[Dict[str, Any]]:
+ """List all registration tokens. Used by the admin API.
+
+ Args:
+ valid: If True, only valid tokens are returned.
+ If False, only invalid tokens are returned.
+ Default is None: return all tokens regardless of validity.
+
+ Returns:
+ A list of dicts, each containing details of a token.
+ """
+
+ def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+ if valid is None:
+ # Return all tokens regardless of validity
+ txn.execute("SELECT * FROM registration_tokens")
+
+ elif valid:
+ # Select valid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
+ "AND (expiry_time > ? OR expiry_time IS NULL)"
+ )
+ txn.execute(sql, [now])
+
+ else:
+ # Select invalid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "uses_allowed <= pending + completed OR expiry_time <= ?"
+ )
+ txn.execute(sql, [now])
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "select_registration_tokens",
+ select_registration_tokens_txn,
+ self._clock.time_msec(),
+ valid,
+ )
+
+ async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
+ """Get info about the given registration token. Used by the admin API.
+
+ Args:
+ token: The token to retrieve information about.
+
+ Returns:
+ A dict, or None if token doesn't exist.
+ """
+ return await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ desc="get_one_registration_token",
+ )
+
+ async def generate_registration_token(
+ self, length: int, chars: str
+ ) -> Optional[str]:
+ """Generate a random registration token. Used by the admin API.
+
+ Args:
+ length: The length of the token to generate.
+ chars: A string of the characters allowed in the generated token.
+
+ Returns:
+ The generated token.
+
+ Raises:
+ SynapseError if a unique registration token could still not be
+ generated after a few tries.
+ """
+ # Make a few attempts at generating a unique token of the required
+ # length before failing.
+ for _i in range(3):
+ # Generate token
+ token = "".join(random.choices(chars, k=length))
+
+ # Check if the token already exists
+ existing_token = await self.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="token",
+ allow_none=True,
+ desc="check_if_registration_token_exists",
+ )
+
+ if existing_token is None:
+ # The generated token doesn't exist yet, return it
+ return token
+
+ raise SynapseError(
+ 500,
+ "Unable to generate a unique registration token. Try again with a greater length",
+ Codes.UNKNOWN,
+ )
+
+ async def create_registration_token(
+ self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
+ ) -> bool:
+ """Create a new registration token. Used by the admin API.
+
+ Args:
+ token: The token to create.
+ uses_allowed: The number of times the token can be used to complete
+ a registration before it becomes invalid. A value of None indicates
+ unlimited uses.
+ expiry_time: The latest time the token is valid. Given as the
+ number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
+ None indicates that the token does not expire.
+
+ Returns:
+ Whether the row was inserted or not.
+ """
+
+ def _create_registration_token_txn(txn):
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token"],
+ allow_none=True,
+ )
+
+ if row is not None:
+ # Token already exists
+ return False
+
+ self.db_pool.simple_insert_txn(
+ txn,
+ "registration_tokens",
+ values={
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ },
+ )
+
+ return True
+
+ return await self.db_pool.runInteraction(
+ "create_registration_token", _create_registration_token_txn
+ )
+
+ async def update_registration_token(
+ self, token: str, updatevalues: Dict[str, Optional[int]]
+ ) -> Optional[Dict[str, Any]]:
+ """Update a registration token. Used by the admin API.
+
+ Args:
+ token: The token to update.
+ updatevalues: A dict with the fields to update. E.g.:
+ `{"uses_allowed": 3}` to update just uses_allowed, or
+ `{"uses_allowed": 3, "expiry_time": None}` to update both.
+ This is passed straight to simple_update_one.
+
+ Returns:
+ A dict with all info about the token, or None if token doesn't exist.
+ """
+
+ def _update_registration_token_txn(txn):
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues=updatevalues,
+ )
+ except StoreError:
+ # Update failed because token does not exist
+ return None
+
+ # Get all info about the token so it can be sent in the response
+ return self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=[
+ "token",
+ "uses_allowed",
+ "pending",
+ "completed",
+ "expiry_time",
+ ],
+ allow_none=True,
+ )
+
+ return await self.db_pool.runInteraction(
+ "update_registration_token", _update_registration_token_txn
+ )
+
+ async def delete_registration_token(self, token: str) -> bool:
+ """Delete a registration token. Used by the admin API.
+
+ Args:
+ token: The token to delete.
+
+ Returns:
+ Whether the token was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ desc="delete_registration_token",
+ )
+ except StoreError:
+ # Deletion failed because token does not exist
+ return False
+
+ return True
+
@cached()
async def mark_access_token_as_used(self, token_id: int) -> None:
"""
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad..4d6bbc94c7 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import attr
+from synapse.api.constants import LoginType
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
keyvalues={},
)
+ # If a registration token was used, decrement the pending counter
+ # before deleting the session.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ )
+
+ # Get the tokens used and how much pending needs to be decremented by.
+ token_counts: Dict[str, int] = {}
+ for r in rows:
+ # If registration was successfully completed, the result of the
+ # registration token stage for that session will be True.
+ # If a token was used to authenticate, but registration was
+ # never completed, the result will be the token used.
+ token = db_to_json(r["result"])
+ if isinstance(token, str):
+ token_counts[token] = token_counts.get(token, 0) + 1
+
+ # Update the `pending` counters.
+ if len(token_counts) > 0:
+ token_rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ )
+ for token_row in token_rows:
+ token = token_row["token"]
+ new_pending = token_row["pending"] - token_counts[token]
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": new_pending},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
new file mode 100644
index 0000000000..ee6cf958f4
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
@@ -0,0 +1,23 @@
+/* Copyright 2021 Callum Brown
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS registration_tokens(
+ token TEXT NOT NULL, -- The token that can be used for authentication.
+ uses_allowed INT, -- The total number of times this token can be used. NULL if no limit.
+ pending INT NOT NULL, -- The number of in progress registrations using this token.
+ completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
+ expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
+ UNIQUE (token)
+);
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 0000000000..4927321e5a
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/registration_tokens"
+
+ def _new_token(self, **kwargs):
+ """Helper function to create a token."""
+ token = kwargs.get(
+ "token",
+ "".join(random.choices(string.ascii_letters, k=8)),
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": kwargs.get("uses_allowed", None),
+ "pending": kwargs.get("pending", 0),
+ "completed": kwargs.get("completed", 0),
+ "expiry_time": kwargs.get("expiry_time", None),
+ },
+ )
+ )
+ return token
+
+ # CREATION
+
+ def test_create_no_auth(self):
+ """Try to create a token without authentication."""
+ channel = self.make_request("POST", self.url + "/new", {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_create_requester_not_admin(self):
+ """Try to create a token while not an admin."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_create_using_defaults(self):
+ """Create a token using all the defaults."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_specifying_fields(self):
+ """Create a token specifying the value of all fields."""
+ data = {
+ "token": "abcd",
+ "uses_allowed": 1,
+ "expiry_time": self.clock.time_msec() + 1000000,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], "abcd")
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_with_null_value(self):
+ """Create a token specifying unlimited uses and no expiry."""
+ data = {
+ "uses_allowed": None,
+ "expiry_time": None,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_token_too_long(self):
+ """Check token longer than 64 chars is invalid."""
+ data = {"token": "a" * 65}
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_invalid_chars(self):
+ """Check you can't create token with invalid characters."""
+ data = {
+ "token": "abc/def",
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_already_exists(self):
+ """Check you can't create token that already exists."""
+ data = {
+ "token": "abcd",
+ }
+
+ channel1 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+ channel2 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+ self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_unable_to_generate_token(self):
+ """Check right error is raised when server can't generate unique token."""
+ # Create all possible single character tokens
+ tokens = []
+ for c in string.ascii_letters + string.digits + "-_":
+ tokens.append(
+ {
+ "token": c,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ }
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert_many(
+ "registration_tokens",
+ tokens,
+ "create_all_registration_tokens",
+ )
+ )
+
+ # Check creating a single character token fails with a 500 status code
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_create_uses_allowed(self):
+ """Check you can only create a token with good values for uses_allowed."""
+ # Should work with 0 (token is invalid from the start)
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+ # Should fail with negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_expiry_time(self):
+ """Check you can't create a token with an invalid expiry_time."""
+ # Should fail with a time in the past
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() - 10000},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() + 1000000.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_length(self):
+ """Check you can only generate a token with a valid length."""
+ # Should work with 64
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 64},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 64)
+
+ # Should fail with 0
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 8.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with 65
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 65},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # UPDATING
+
+ def test_update_no_auth(self):
+ """Try to update a token without authentication."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_update_requester_not_admin(self):
+ """Try to update a token while not an admin."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_update_non_existent(self):
+ """Try to update a token that doesn't exist."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234",
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_update_uses_allowed(self):
+ """Test updating just uses_allowed."""
+ # Create new token using default values
+ token = self._new_token()
+
+ # Should succeed with 1
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with 0 (makes token invalid)
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should fail with a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_expiry_time(self):
+ """Test updating just expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ # Should succeed with a time in the future
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should fail with a time in the past
+ past_time = self.clock.time_msec() - 10000
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": past_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time + 0.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_both(self):
+ """Test updating both uses_allowed and expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ data = {
+ "uses_allowed": 1,
+ "expiry_time": new_expiry_time,
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+ def test_update_invalid_type(self):
+ """Test using invalid types doesn't work."""
+ # Create new token using default values
+ token = self._new_token()
+
+ data = {
+ "uses_allowed": False,
+ "expiry_time": "1626430124000",
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # DELETING
+
+ def test_delete_no_auth(self):
+ """Try to delete a token without authentication."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_delete_requester_not_admin(self):
+ """Try to delete a token while not an admin."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_non_existent(self):
+ """Try to delete a token that doesn't exist."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_delete(self):
+ """Test deleting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # GETTING ONE
+
+ def test_get_no_auth(self):
+ """Try to get a token without authentication."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_get_requester_not_admin(self):
+ """Try to get a token while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_get_non_existent(self):
+ """Try to get a token that doesn't exist."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_get(self):
+ """Test getting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], token)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ # LISTING
+
+ def test_list_no_auth(self):
+ """Try to list tokens without authentication."""
+ channel = self.make_request("GET", self.url, {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_list_requester_not_admin(self):
+ """Try to list tokens while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_list_all(self):
+ """Test listing all tokens."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+ token_info = channel.json_body["registration_tokens"][0]
+ self.assertEqual(token_info["token"], token)
+ self.assertIsNone(token_info["uses_allowed"])
+ self.assertIsNone(token_info["expiry_time"])
+ self.assertEqual(token_info["pending"], 0)
+ self.assertEqual(token_info["completed"], 0)
+
+ def test_list_invalid_query_parameter(self):
+ """Test with `valid` query parameter not `true` or `false`."""
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=x",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _test_list_query_parameter(self, valid: str):
+ """Helper used to test both valid=true and valid=false."""
+ # Create 2 valid and 2 invalid tokens.
+ now = self.hs.get_clock().time_msec()
+ # Create always valid token
+ valid1 = self._new_token()
+ # Create token that hasn't been used up
+ valid2 = self._new_token(uses_allowed=1)
+ # Create token that has expired
+ invalid1 = self._new_token(expiry_time=now - 10000)
+ # Create token that has been used up but hasn't expired
+ invalid2 = self._new_token(
+ uses_allowed=2,
+ pending=1,
+ completed=1,
+ expiry_time=now + 1000000,
+ )
+
+ if valid == "true":
+ tokens = [valid1, valid2]
+ else:
+ tokens = [invalid1, invalid2]
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=" + valid,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+ token_info_1 = channel.json_body["registration_tokens"][0]
+ token_info_2 = channel.json_body["registration_tokens"][1]
+ self.assertIn(token_info_1["token"], tokens)
+ self.assertIn(token_info_2["token"], tokens)
+
+ def test_list_valid(self):
+ """Test listing just valid tokens."""
+ self._test_list_query_parameter(valid="true")
+
+ def test_list_invalid(self):
+ """Test listing just invalid tokens."""
+ self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
from tests import unittest
from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_requires_token(self):
+ username = "kermit"
+ device_id = "frogfone"
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params = {
+ "username": username,
+ "password": "monkey",
+ "device_id": device_id,
+ }
+
+ # Request without auth to get flows and session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+ # Synapse adds a dummy stage to differentiate flows where otherwise one
+ # flow would be a subset of another flow.
+ self.assertCountEqual(
+ [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+ (f["stages"] for f in flows),
+ )
+ session = channel.json_body["session"]
+
+ # Do the registration token stage and check it has completed
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ # Do the m.login.dummy stage and check registration was successful
+ params["auth"] = {
+ "type": LoginType.DUMMY,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ det_data = {
+ "user_id": f"@{username}:{self.hs.hostname}",
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, channel.json_body)
+
+ # Check the `completed` counter has been incremented and pending is 0
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["completed"], 1)
+ self.assertEquals(res["pending"], 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_invalid(self):
+ params = {
+ "username": "kermit",
+ "password": "monkey",
+ }
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Test with token param missing (invalid)
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with non-string (invalid)
+ params["auth"]["token"] = 1234
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with unknown token (invalid)
+ params["auth"]["token"] = "1234"
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_limit_uses(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ # Create token that can be used once
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ # Do 2 requests without auth to get two session IDs
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with session1 and check `pending` is 1
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Repeat request to make sure pending isn't increased again
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 1)
+
+ # Check auth fails when using token with session2
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Check pending=0 and completed=1
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["pending"], 0)
+ self.assertEquals(res["completed"], 1)
+
+ # Check auth still fails when using token with session2
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_expiry(self):
+ token = "abcd"
+ now = self.hs.get_clock().time_msec()
+ store = self.hs.get_datastore()
+ # Create token that expired yesterday
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": now - 24 * 60 * 60 * 1000,
+ },
+ )
+ )
+ params = {"username": "kermit", "password": "monkey"}
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Check authentication fails with expired token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Update token so it expires tomorrow
+ self.get_success(
+ store.db_pool.simple_update_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+ )
+ )
+
+ # Check authentication succeeds
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry(self):
+ """Test `pending` is decremented when an uncompleted session expires."""
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do 2 requests without auth to get two session IDs
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with both sessions
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params2))
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ # Check `result` of registration token stage for session1 is `True`
+ result1 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session1,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertTrue(db_to_json(result1))
+
+ # Check `result` for session2 is the token used
+ result2 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session2,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertEquals(db_to_json(result2), token)
+
+ # Delete both sessions (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
+ # Check pending is now 0
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry_deleted_token(self):
+ """Test session expiry doesn't break when the token is deleted.
+
+ 1. Start but don't complete UIA with a registration token
+ 2. Delete the token from the database
+ 3. Expire the session
+ """
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do request without auth to get a session ID
+ params = {"username": "kermit", "password": "monkey"}
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Use token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params))
+
+ # Delete token
+ self.get_success(
+ store.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ )
+ )
+
+ # Delete session (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
def test_advertised_flows(self):
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [register.register_servlets]
+ url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+ def default_config(self):
+ config = super().default_config()
+ config["registration_requires_token"] = True
+ return config
+
+ def test_GET_token_valid(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], True)
+
+ def test_GET_token_invalid(self):
+ token = "1234"
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], False)
+
+ @override_config(
+ {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+ )
+ def test_GET_ratelimiting(self):
+ token = "1234"
+
+ for i in range(0, 6):
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+
+ if i == 5:
+ self.assertEquals(channel.result["code"], b"429", channel.result)
+ retry_after_ms = int(channel.json_body["retry_after_ms"])
+ else:
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
--
cgit 1.5.1
From 31dac7ffeeb02f68d1dbe068fd241239e02208dc Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 23 Aug 2021 08:00:25 -0400
Subject: Do not include stack traces for known exceptions when trying multiple
federation destinations. (#10662)
---
changelog.d/10662.misc | 1 +
synapse/federation/federation_client.py | 7 ++++++-
2 files changed, 7 insertions(+), 1 deletion(-)
create mode 100644 changelog.d/10662.misc
(limited to 'synapse')
diff --git a/changelog.d/10662.misc b/changelog.d/10662.misc
new file mode 100644
index 0000000000..593f9ceaad
--- /dev/null
+++ b/changelog.d/10662.misc
@@ -0,0 +1 @@
+Do not print out stack traces for network errors when fetching data over federation.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 29979414e3..44d9e8a5c7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
+ RequestSendFailed,
SynapseError,
UnsupportedRoomVersionError,
)
@@ -558,7 +559,11 @@ class FederationClient(FederationBase):
try:
return await callback(destination)
- except InvalidResponseError as e:
+ except (
+ RequestSendFailed,
+ InvalidResponseError,
+ NotRetryingDestination,
+ ) as e:
logger.warning("Failed to %s via %s: %s", description, destination, e)
except UnsupportedRoomVersionError:
raise
--
cgit 1.5.1
From 2af6d31b78109a989e27128ac655990c35b29d62 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 23 Aug 2021 08:14:17 -0400
Subject: Addtional type hints for the REST servlets. (#10665)
---
changelog.d/10665.misc | 1 +
synapse/rest/client/account_validity.py | 39 ++++++------
synapse/rest/client/capabilities.py | 3 +-
synapse/rest/client/directory.py | 78 +++++++++++++++---------
synapse/rest/client/pusher.py | 22 ++++---
synapse/rest/client/read_marker.py | 15 ++++-
synapse/rest/client/room_upgrade_rest_servlet.py | 18 ++++--
synapse/rest/client/shared_rooms.py | 16 +++--
synapse/rest/client/tags.py | 25 ++++++--
synapse/rest/client/thirdparty.py | 36 +++++++----
synapse/rest/client/tokenrefresh.py | 14 ++++-
synapse/rest/client/user_directory.py | 17 +++---
synapse/rest/client/versions.py | 14 ++++-
synapse/rest/client/voip.py | 13 +++-
14 files changed, 204 insertions(+), 107 deletions(-)
create mode 100644 changelog.d/10665.misc
(limited to 'synapse')
diff --git a/changelog.d/10665.misc b/changelog.d/10665.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10665.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py
index 3ebe401861..6c24b96c54 100644
--- a/synapse/rest/client/account_validity.py
+++ b/synapse/rest/client/account_validity.py
@@ -13,24 +13,27 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
-from synapse.api.errors import SynapseError
-from synapse.http.server import respond_with_html
-from synapse.http.servlet import RestServlet
+from twisted.web.server import Request
+
+from synapse.http.server import HttpServer, respond_with_html
+from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -46,18 +49,14 @@ class AccountValidityRenewServlet(RestServlet):
hs.config.account_validity.account_validity_invalid_token_template
)
- async def on_GET(self, request):
- if b"token" not in request.args:
- raise SynapseError(400, "Missing renewal token")
- renewal_token = request.args[b"token"][0]
+ async def on_GET(self, request: Request) -> None:
+ renewal_token = parse_string(request, "token", required=True)
(
token_valid,
token_stale,
expiration_ts,
- ) = await self.account_activity_handler.renew_account(
- renewal_token.decode("utf8")
- )
+ ) = await self.account_activity_handler.renew_account(renewal_token)
if token_valid:
status_code = 200
@@ -77,11 +76,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/send_mail$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -91,7 +86,7 @@ class AccountValiditySendMailServlet(RestServlet):
hs.config.account_validity.account_validity_renew_by_email_enabled
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id)
@@ -99,6 +94,6 @@ class AccountValiditySendMailServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountValidityRenewServlet(hs).register(http_server)
AccountValiditySendMailServlet(hs).register(http_server)
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 093549512e..65b3b5ce2c 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -15,6 +15,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
@@ -75,5 +76,5 @@ class CapabilitiesRestServlet(RestServlet):
return 200, response
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
CapabilitiesRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index ffa075c8e5..ee247e3d1e 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
from synapse.api.errors import (
AuthError,
@@ -22,14 +24,19 @@ from synapse.api.errors import (
NotFoundError,
SynapseError,
)
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server)
@@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_alias):
- room_alias = RoomAlias.from_string(room_alias)
+ async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
- res = await self.directory_handler.get_association(room_alias)
+ res = await self.directory_handler.get_association(room_alias_obj)
return 200, res
- async def on_PUT(self, request, room_alias):
- room_alias = RoomAlias.from_string(room_alias)
+ async def on_PUT(
+ self, request: SynapseRequest, room_alias: str
+ ) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
if "room_id" not in content:
@@ -61,7 +70,7 @@ class ClientDirectoryServer(RestServlet):
)
logger.debug("Got content: %s", content)
- logger.debug("Got room name: %s", room_alias.to_string())
+ logger.debug("Got room name: %s", room_alias_obj.to_string())
room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None
@@ -78,22 +87,25 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.create_association(
- requester, room_alias, room_id, servers
+ requester, room_alias_obj, room_id, servers
)
return 200, {}
- async def on_DELETE(self, request, room_alias):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_alias: str
+ ) -> Tuple[int, JsonDict]:
+ room_alias_obj = RoomAlias.from_string(room_alias)
+
try:
service = self.auth.get_appservice_by_req(request)
- room_alias = RoomAlias.from_string(room_alias)
await self.directory_handler.delete_appservice_association(
- service, room_alias
+ service, room_alias_obj
)
logger.info(
"Application service at %s deleted alias %s",
service.url,
- room_alias.to_string(),
+ room_alias_obj.to_string(),
)
return 200, {}
except InvalidClientCredentialsError:
@@ -103,12 +115,10 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
user = requester.user
- room_alias = RoomAlias.from_string(room_alias)
-
- await self.directory_handler.delete_association(requester, room_alias)
+ await self.directory_handler.delete_association(requester, room_alias_obj)
logger.info(
- "User %s deleted alias %s", user.to_string(), room_alias.to_string()
+ "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
)
return 200, {}
@@ -117,20 +127,22 @@ class ClientDirectoryServer(RestServlet):
class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
- async def on_PUT(self, request, room_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -142,7 +154,9 @@ class ClientDirectoryListServer(RestServlet):
return 200, {}
- async def on_DELETE(self, request, room_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.edit_published_room_list(
@@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
"/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
- def on_PUT(self, request, network_id, room_id):
+ async def on_PUT(
+ self, request: SynapseRequest, network_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- return self._edit(request, network_id, room_id, visibility)
+ return await self._edit(request, network_id, room_id, visibility)
- def on_DELETE(self, request, network_id, room_id):
- return self._edit(request, network_id, room_id, "private")
+ async def on_DELETE(
+ self, request: SynapseRequest, network_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ return await self._edit(request, network_id, room_id, "private")
- async def _edit(self, request, network_id, room_id, visibility):
+ async def _edit(
+ self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 84619c5e41..98604a9388 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -13,17 +13,23 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.http.server import respond_with_html_bytes
+from synapse.http.server import HttpServer, respond_with_html_bytes
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,12 +37,12 @@ logger = logging.getLogger(__name__)
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@@ -50,14 +56,14 @@ class PushersRestServlet(RestServlet):
class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user
@@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"You have been unsubscribed"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
@@ -165,7 +171,7 @@ class PushersRemoveRestServlet(RestServlet):
return None
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 027f8b81fa..43c04fac6f 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -13,27 +13,36 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await self.presence_handler.bump_presence_active_time(requester.user)
@@ -70,5 +79,5 @@ class ReadMarkerRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReadMarkerRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 6d1b083acb..6a7792e18b 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -13,18 +13,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util import stringutils
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet):
}
Creates a new room and shuts down the old one. Returns the ID of the new room.
-
- Args:
- hs (synapse.server.HomeServer):
"""
PATTERNS = client_patterns(
@@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet):
"/rooms/(?P[^/]*)/upgrade$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -84,5 +90,5 @@ class RoomUpgradeRestServlet(RestServlet):
return 200, ret
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomUpgradeRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py
index d2e7f04b40..1d90493eb0 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/shared_rooms.py
@@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
-from synapse.types import UserID
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet):
releases=(), # This is an unstable feature
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
if not self.user_directory_active:
raise SynapseError(
@@ -63,5 +71,5 @@ class UserSharedRoomsServlet(RestServlet):
return 200, {"joined": list(rooms)}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index c14f83be18..c88cb9367c 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -13,12 +13,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -29,12 +36,14 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
@@ -54,12 +63,14 @@ class TagServlet(RestServlet):
"/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, room_id, tag):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@@ -70,7 +81,9 @@ class TagServlet(RestServlet):
return 200, {}
- async def on_DELETE(self, request, user_id, room_id, tag):
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
@@ -80,6 +93,6 @@ class TagServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)
diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py
index b5c67c9bb6..b895c73acf 100644
--- a/synapse/rest/client/thirdparty.py
+++ b/synapse/rest/client/thirdparty.py
@@ -12,27 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.constants import ThirdPartyEntityKind
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols()
@@ -42,13 +48,15 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
protocols = await self.appservice_handler.get_3pe_protocols(
@@ -63,16 +71,18 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
- fields = request.args
+ fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@@ -85,16 +95,18 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- async def on_GET(self, request, protocol):
+ async def on_GET(
+ self, request: SynapseRequest, protocol: str
+ ) -> Tuple[int, List[JsonDict]]:
await self.auth.get_user_by_req(request, allow_guest=True)
- fields = request.args
+ fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment]
fields.pop(b"access_token", None)
results = await self.appservice_handler.query_3pe(
@@ -104,7 +116,7 @@ class ThirdPartyLocationServlet(RestServlet):
return 200, results
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py
index b2f858545c..c8c3b25bd3 100644
--- a/synapse/rest/client/tokenrefresh.py
+++ b/synapse/rest/client/tokenrefresh.py
@@ -12,11 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class TokenRefreshRestServlet(RestServlet):
"""
@@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> None:
raise AuthError(403, "tokenrefresh is no longer supported.")
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index 7e8912f0b9..8852811114 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -13,29 +13,32 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Searches for users in directory
Returns:
@@ -75,5 +78,5 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index fa2e4e9cba..a1a815cf82 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -17,9 +17,17 @@
import logging
import re
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
from synapse.api.constants import RoomCreationPreset
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -27,7 +35,7 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
@@ -45,7 +53,7 @@ class VersionsRestServlet(RestServlet):
in self.config.encryption_enabled_by_default_for_room_presets
)
- def on_GET(self, request):
+ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
return (
200,
{
@@ -89,5 +97,5 @@ class VersionsRestServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VersionsRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index f53020520d..9d46ed3af3 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -15,20 +15,27 @@
import base64
import hashlib
import hmac
+from typing import TYPE_CHECKING, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests
)
@@ -69,5 +76,5 @@ class VoipRestServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
VoipRestServlet(hs).register(http_server)
--
cgit 1.5.1
From bd7d398b05aaa18d5b0629153ababeea7539256c Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 23 Aug 2021 08:14:42 -0400
Subject: Additional type hints for the sync REST servlet. (#10666)
---
changelog.d/10666.misc | 1 +
synapse/handlers/sync.py | 21 +++----
synapse/rest/client/sync.py | 132 +++++++++++++++++++++++++++-----------------
3 files changed, 93 insertions(+), 61 deletions(-)
create mode 100644 changelog.d/10666.misc
(limited to 'synapse')
diff --git a/changelog.d/10666.misc b/changelog.d/10666.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10666.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2203c45dcc..86c3c7f0df 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -30,6 +30,7 @@ from prometheus_client import Counter
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.filtering import FilterCollection
+from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging.context import current_context
@@ -231,7 +232,7 @@ class SyncResult:
"""
next_batch: StreamToken
- presence: List[JsonDict]
+ presence: List[UserPresenceState]
account_data: List[JsonDict]
joined: List[JoinedSyncResult]
invited: List[InvitedSyncResult]
@@ -2177,14 +2178,14 @@ class SyncResultBuilder:
joined_room_ids: List of rooms the user is joined to
# The following mirror the fields in a sync response
- presence (list)
- account_data (list)
- joined (list[JoinedSyncResult])
- invited (list[InvitedSyncResult])
- knocked (list[KnockedSyncResult])
- archived (list[ArchivedSyncResult])
- groups (GroupsSyncResult|None)
- to_device (list)
+ presence
+ account_data
+ joined
+ invited
+ knocked
+ archived
+ groups
+ to_device
"""
sync_config: SyncConfig
@@ -2193,7 +2194,7 @@ class SyncResultBuilder:
now_token: StreamToken
joined_room_ids: FrozenSet[str]
- presence: List[JsonDict] = attr.Factory(list)
+ presence: List[UserPresenceState] = attr.Factory(list)
account_data: List[JsonDict] = attr.Factory(list)
joined: List[JoinedSyncResult] = attr.Factory(list)
invited: List[InvitedSyncResult] = attr.Factory(list)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e18f4d01b3..65c37be3e9 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,17 +14,26 @@
import itertools
import logging
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.api.presence import UserPresenceState
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import KnockedSyncResult, SyncConfig
+from synapse.handlers.sync import (
+ ArchivedSyncResult,
+ InvitedSyncResult,
+ JoinedSyncResult,
+ KnockedSyncResult,
+ SyncConfig,
+ SyncResult,
+)
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, StreamToken
@@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet):
return 200, {}
time_now = self.clock.time_msec()
+ # We know that the the requester has an access token since appservices
+ # cannot use sync.
response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
)
@@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
- async def encode_response(self, time_now, sync_result, access_token_id, filter):
+ async def encode_response(
+ self,
+ time_now: int,
+ sync_result: SyncResult,
+ access_token_id: Optional[int],
+ filter: FilterCollection,
+ ) -> JsonDict:
logger.debug("Formatting events in sync response")
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
@@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet):
logger.debug("building sync response dict")
- response: dict = defaultdict(dict)
+ response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.account_data:
@@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet):
if archived:
response["rooms"][Membership.LEAVE] = archived
+ # By the time we get here groups is no longer optional.
+ assert sync_result.groups is not None
if sync_result.groups.join:
response["groups"][Membership.JOIN] = sync_result.groups.join
if sync_result.groups.invite:
@@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet):
return response
@staticmethod
- def encode_presence(events, time_now):
+ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
return {
"events": [
{
@@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet):
}
async def encode_joined(
- self, rooms, time_now, token_id, event_fields, event_formatter
- ):
+ self,
+ rooms: List[JoinedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_fields: List[str],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the joined rooms in a sync result
Args:
- rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
- results for rooms this user is joined to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is joined to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_fields(list): List of event fields to include. If empty,
+ event_fields: List of event fields to include. If empty,
all fields will be returned.
- event_formatter (func[dict]): function to convert from federation format
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, dict[str, object]]: the joined rooms list, in our
- response format
+ The joined rooms list, in our response format
"""
joined = {}
for room in rooms:
@@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet):
return joined
- async def encode_invited(self, rooms, time_now, token_id, event_formatter):
+ async def encode_invited(
+ self,
+ rooms: List[InvitedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the invited rooms in a sync result
Args:
- rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
- sync results for rooms this user is invited to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is invited to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_formatter (func[dict]): function to convert from federation format
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, dict[str, object]]: the invited rooms list, in our
- response format
+ The invited rooms list, in our response format
"""
invited = {}
for room in rooms:
@@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet):
self,
rooms: List[KnockedSyncResult],
time_now: int,
- token_id: int,
+ token_id: Optional[int],
event_formatter: Callable[[Dict], Dict],
) -> Dict[str, Dict[str, Any]]:
"""
@@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet):
return knocked
async def encode_archived(
- self, rooms, time_now, token_id, event_fields, event_formatter
- ):
+ self,
+ rooms: List[ArchivedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ event_fields: List[str],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Encode the archived rooms in a sync result
Args:
- rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
- sync results for rooms this user is joined to
- time_now(int): current time - used as a baseline for age
- calculations
- token_id(int): ID of the user's auth token - used for namespacing
+ rooms: list of sync results for rooms this user is joined to
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- event_fields(list): List of event fields to include. If empty,
+ event_fields: List of event fields to include. If empty,
all fields will be returned.
- event_formatter (func[dict]): function to convert from federation format
- to client format
+ event_formatter: function to convert from federation format to client format
Returns:
- dict[str, dict[str, object]]: The invited rooms list, in our
- response format
+ The archived rooms list, in our response format
"""
joined = {}
for room in rooms:
@@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet):
return joined
async def encode_room(
- self, room, time_now, token_id, joined, only_fields, event_formatter
- ):
+ self,
+ room: Union[JoinedSyncResult, ArchivedSyncResult],
+ time_now: int,
+ token_id: Optional[int],
+ joined: bool,
+ only_fields: Optional[List[str]],
+ event_formatter: Callable[[JsonDict], JsonDict],
+ ) -> JsonDict:
"""
Args:
- room (JoinedSyncResult|ArchivedSyncResult): sync result for a
- single room
- time_now (int): current time - used as a baseline for age
- calculations
- token_id (int): ID of the user's auth token - used for namespacing
+ room: sync result for a single room
+ time_now: current time - used as a baseline for age calculations
+ token_id: ID of the user's auth token - used for namespacing
of transaction IDs
- joined (bool): True if the user is joined to this room - will mean
+ joined: True if the user is joined to this room - will mean
we handle ephemeral events
- only_fields(list): Optional. The list of event fields to include.
- event_formatter (func[dict]): function to convert from federation format
+ only_fields: Optional. The list of event fields to include.
+ event_formatter: function to convert from federation format
to client format
Returns:
- dict[str, object]: the room, encoded in our response format
+ The room, encoded in our response format
"""
def serialize(events):
@@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet):
account_data = room.account_data
- result = {
+ result: JsonDict = {
"timeline": {
"events": serialized_timeline,
"prev_batch": await room.timeline.prev_batch.to_string(self.store),
@@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet):
}
if joined:
+ assert isinstance(room, JoinedSyncResult)
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
@@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet):
return result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
--
cgit 1.5.1
From 0c1d6f65d7c65efd8491adf4efc2620148e2841a Mon Sep 17 00:00:00 2001
From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com>
Date: Mon, 23 Aug 2021 16:25:33 +0100
Subject: Enforce the max length for per-room display names / avatar URLs.
(#10654)
To match the maximum lengths allowed for profile data.
---
changelog.d/10654.bugfix | 1 +
synapse/handlers/room_member.py | 17 ++++++++++++++++-
2 files changed, 17 insertions(+), 1 deletion(-)
create mode 100644 changelog.d/10654.bugfix
(limited to 'synapse')
diff --git a/changelog.d/10654.bugfix b/changelog.d/10654.bugfix
new file mode 100644
index 0000000000..b0bd78453f
--- /dev/null
+++ b/changelog.d/10654.bugfix
@@ -0,0 +1 @@
+Enforce the maximum length for per-room display names and avatar URLs.
\ No newline at end of file
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ba13196218..401b84aad1 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -36,6 +36,7 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import (
JsonDict,
Requester,
@@ -79,7 +80,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()
- self.member_linearizer = Linearizer(name="member")
+ self.member_linearizer: Linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
@@ -556,6 +557,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content.pop("displayname", None)
content.pop("avatar_url", None)
+ if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN:
+ raise SynapseError(
+ 400,
+ f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN:
+ raise SynapseError(
+ 400,
+ f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})",
+ errcode=Codes.BAD_JSON,
+ )
+
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
--
cgit 1.5.1
From 15db8b7c7f13f33ca49104183e0642892c3b83f1 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Tue, 24 Aug 2021 10:17:51 +0100
Subject: Correctly initialise the `synapse_user_logins` metric. (#10677)
Fix a bug where the prometheus metrics for SSO logins wouldn't be initialised
until the first user logged in with a given auth provider.
---
changelog.d/10677.bugfix | 1 +
synapse/handlers/register.py | 18 ++++++++++++++++++
synapse/handlers/sso.py | 2 ++
synapse/rest/client/login.py | 29 +++++++++++++++++++++++------
4 files changed, 44 insertions(+), 6 deletions(-)
create mode 100644 changelog.d/10677.bugfix
(limited to 'synapse')
diff --git a/changelog.d/10677.bugfix b/changelog.d/10677.bugfix
new file mode 100644
index 0000000000..9964afaaee
--- /dev/null
+++ b/changelog.d/10677.bugfix
@@ -0,0 +1 @@
+Fix a bug which caused the `synapse_user_logins_total` Prometheus metric not to be correctly initialised on restart.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8cf614136e..0ed59d757b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -56,6 +56,22 @@ login_counter = Counter(
)
+def init_counters_for_auth_provider(auth_provider_id: str) -> None:
+ """Ensure the prometheus counters for the given auth provider are initialised
+
+ This fixes a problem where the counters are not reported for a given auth provider
+ until the user first logs in/registers.
+ """
+ for is_guest in (True, False):
+ login_counter.labels(guest=is_guest, auth_provider=auth_provider_id)
+ for shadow_banned in (True, False):
+ registration_counter.labels(
+ guest=is_guest,
+ shadow_banned=shadow_banned,
+ auth_provider=auth_provider_id,
+ )
+
+
class LoginDict(TypedDict):
device_id: str
access_token: str
@@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
self.access_token_lifetime = hs.config.access_token_lifetime
+ init_counters_for_auth_provider("")
+
async def check_username(
self,
localpart: str,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 1b855a685c..0e6ebb574e 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.register import init_counters_for_auth_provider
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html, respond_with_redirect
@@ -213,6 +214,7 @@ class SsoHandler:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
+ init_counters_for_auth_provider(p_id)
def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
"""Get the configured identity providers"""
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 0c8d8967b7..11d07776b2 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -104,6 +104,12 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_account.burst_count,
)
+ # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
+ # The reason for this is to ensure that the auth_provider_ids are registered
+ # with SsoHandler, which in turn ensures that the login/registration prometheus
+ # counters are initialised for the auth_provider_ids.
+ _load_sso_handlers(hs)
+
def on_GET(self, request: SynapseRequest):
flows = []
if self.jwt_enabled:
@@ -499,12 +505,7 @@ class SsoRedirectServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
# register themselves with the main SSOHandler.
- if hs.config.cas_enabled:
- hs.get_cas_handler()
- if hs.config.saml2_enabled:
- hs.get_saml_handler()
- if hs.config.oidc_enabled:
- hs.get_oidc_handler()
+ _load_sso_handlers(hs)
self._sso_handler = hs.get_sso_handler()
self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self._public_baseurl = hs.config.public_baseurl
@@ -598,3 +599,19 @@ def register_servlets(hs, http_server):
SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasTicketServlet(hs).register(http_server)
+
+
+def _load_sso_handlers(hs: "HomeServer"):
+ """Ensure that the SSO handlers are loaded, if they are enabled by configuration.
+
+ This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
+ with the main SsoHandler.
+
+ It's safe to call this multiple times.
+ """
+ if hs.config.cas.cas_enabled:
+ hs.get_cas_handler()
+ if hs.config.saml2.saml2_enabled:
+ hs.get_saml_handler()
+ if hs.config.oidc.oidc_enabled:
+ hs.get_oidc_handler()
--
cgit 1.5.1
From d12ba52f178982ecb47207471bee14472f9597b6 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Tue, 24 Aug 2021 08:14:03 -0400
Subject: Persist room hierarchy pagination sessions to the database. (#10613)
---
changelog.d/10613.feature | 1 +
mypy.ini | 1 +
synapse/app/generic_worker.py | 2 +
synapse/handlers/room_summary.py | 76 +++++------
synapse/storage/databases/main/__init__.py | 2 +
synapse/storage/databases/main/session.py | 145 +++++++++++++++++++++
.../schema/main/delta/62/02session_store.sql | 23 ++++
7 files changed, 212 insertions(+), 38 deletions(-)
create mode 100644 changelog.d/10613.feature
create mode 100644 synapse/storage/databases/main/session.py
create mode 100644 synapse/storage/schema/main/delta/62/02session_store.sql
(limited to 'synapse')
diff --git a/changelog.d/10613.feature b/changelog.d/10613.feature
new file mode 100644
index 0000000000..ffc4e4289c
--- /dev/null
+++ b/changelog.d/10613.feature
@@ -0,0 +1 @@
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
diff --git a/mypy.ini b/mypy.ini
index b17872211e..745e6b78eb 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -57,6 +57,7 @@ files =
synapse/storage/databases/main/keys.py,
synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py,
+ synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index fd2626dbe1..9b71dd75e6 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -118,6 +118,7 @@ from synapse.storage.databases.main.monthly_active_users import (
from synapse.storage.databases.main.presence import PresenceStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@@ -253,6 +254,7 @@ class GenericWorkerSlavedStore(
SearchStore,
TransactionWorkerStore,
LockStore,
+ SessionStore,
BaseSlavedStore,
):
pass
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index ac6cfc0da9..906985c754 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,12 +28,11 @@ from synapse.api.constants import (
Membership,
RoomTypes,
)
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -76,6 +75,9 @@ class _PaginationSession:
class RoomSummaryHandler:
+ # A unique key used for pagination sessions for the room hierarchy endpoint.
+ _PAGINATION_SESSION_TYPE = "room_hierarchy_pagination"
+
# The time a pagination session remains valid for.
_PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
@@ -87,12 +89,6 @@ class RoomSummaryHandler:
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
- # A map of query information to the current pagination state.
- #
- # TODO Allow for multiple workers to share this data.
- # TODO Expire pagination tokens.
- self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
-
# If a user tries to fetch the same page multiple times in quick succession,
# only process the first attempt and return its result to subsequent requests.
self._pagination_response_cache: ResponseCache[
@@ -102,21 +98,6 @@ class RoomSummaryHandler:
"get_room_hierarchy",
)
- def _expire_pagination_sessions(self):
- """Expire pagination session which are old."""
- expire_before = (
- self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
- )
- to_expire = []
-
- for key, value in self._pagination_sessions.items():
- if value.creation_time_ms < expire_before:
- to_expire.append(key)
-
- for key in to_expire:
- logger.debug("Expiring pagination session id %s", key)
- del self._pagination_sessions[key]
-
async def get_space_summary(
self,
requester: str,
@@ -327,18 +308,29 @@ class RoomSummaryHandler:
# If this is continuing a previous session, pull the persisted data.
if from_token:
- self._expire_pagination_sessions()
+ try:
+ pagination_session = await self._store.get_session(
+ session_type=self._PAGINATION_SESSION_TYPE,
+ session_id=from_token,
+ )
+ except StoreError:
+ raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
- pagination_key = _PaginationKey(
- requested_room_id, suggested_only, max_depth, from_token
- )
- if pagination_key not in self._pagination_sessions:
+ # If the requester, room ID, suggested-only, or max depth were modified
+ # the session is invalid.
+ if (
+ requester != pagination_session["requester"]
+ or requested_room_id != pagination_session["room_id"]
+ or suggested_only != pagination_session["suggested_only"]
+ or max_depth != pagination_session["max_depth"]
+ ):
raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
# Load the previous state.
- pagination_session = self._pagination_sessions[pagination_key]
- room_queue = pagination_session.room_queue
- processed_rooms = pagination_session.processed_rooms
+ room_queue = [
+ _RoomQueueEntry(*fields) for fields in pagination_session["room_queue"]
+ ]
+ processed_rooms = set(pagination_session["processed_rooms"])
else:
# The queue of rooms to process, the next room is last on the stack.
room_queue = [_RoomQueueEntry(requested_room_id, ())]
@@ -456,13 +448,21 @@ class RoomSummaryHandler:
# If there's additional data, generate a pagination token (and persist state).
if room_queue:
- next_batch = random_string(24)
- result["next_batch"] = next_batch
- pagination_key = _PaginationKey(
- requested_room_id, suggested_only, max_depth, next_batch
- )
- self._pagination_sessions[pagination_key] = _PaginationSession(
- self._clock.time_msec(), room_queue, processed_rooms
+ result["next_batch"] = await self._store.create_session(
+ session_type=self._PAGINATION_SESSION_TYPE,
+ value={
+ # Information which must be identical across pagination.
+ "requester": requester,
+ "room_id": requested_room_id,
+ "suggested_only": suggested_only,
+ "max_depth": max_depth,
+ # The stored state.
+ "room_queue": [
+ attr.astuple(room_entry) for room_entry in room_queue
+ ],
+ "processed_rooms": list(processed_rooms),
+ },
+ expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS,
)
return result
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 01b918e12e..00a644e8f7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,6 +63,7 @@ from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
+from .session import SessionStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
@@ -121,6 +122,7 @@ class DataStore(
ServerMetricsStore,
EventForwardExtremitiesStore,
LockStore,
+ SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
new file mode 100644
index 0000000000..172f27d109
--- /dev/null
+++ b/synapse/storage/databases/main/session.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class SessionStore(SQLBaseStore):
+ """
+ A store for generic session data.
+
+ Each type of session should provide a unique type (to separate sessions).
+
+ Sessions are automatically removed when they expire.
+ """
+
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ # Create a background job for culling expired sessions.
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
+
+ async def create_session(
+ self, session_type: str, value: JsonDict, expiry_ms: int
+ ) -> str:
+ """
+ Creates a new pagination session for the room hierarchy endpoint.
+
+ Args:
+ session_type: The type for this session.
+ value: The value to store.
+ expiry_ms: How long before an item is evicted from the cache
+ in milliseconds. Default is 0, indicating items never get
+ evicted based on time.
+
+ Returns:
+ The newly created session ID.
+
+ Raises:
+ StoreError if a unique session ID cannot be generated.
+ """
+ # autogen a session ID and try to create it. We may clash, so just
+ # try a few times till one goes through, giving up eventually.
+ attempts = 0
+ while attempts < 5:
+ session_id = stringutils.random_string(24)
+
+ try:
+ await self.db_pool.simple_insert(
+ table="sessions",
+ values={
+ "session_id": session_id,
+ "session_type": session_type,
+ "value": json_encoder.encode(value),
+ "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
+ },
+ desc="create_session",
+ )
+
+ return session_id
+ except self.db_pool.engine.module.IntegrityError:
+ attempts += 1
+ raise StoreError(500, "Couldn't generate a session ID.")
+
+ async def get_session(self, session_type: str, session_id: str) -> JsonDict:
+ """
+ Retrieve data stored with create_session
+
+ Args:
+ session_type: The type for this session.
+ session_id: The session ID returned from create_session.
+
+ Raises:
+ StoreError if the session cannot be found.
+ """
+
+ def _get_session(
+ txn: LoggingTransaction, session_type: str, session_id: str, ts: int
+ ) -> JsonDict:
+ # This includes the expiry time since items are only periodically
+ # deleted, not upon expiry.
+ select_sql = """
+ SELECT value FROM sessions WHERE
+ session_type = ? AND session_id = ? AND expiry_time_ms > ?
+ """
+ txn.execute(select_sql, [session_type, session_id, ts])
+ row = txn.fetchone()
+
+ if not row:
+ raise StoreError(404, "No session")
+
+ return db_to_json(row[0])
+
+ return await self.db_pool.runInteraction(
+ "get_session",
+ _get_session,
+ session_type,
+ session_id,
+ self._clock.time_msec(),
+ )
+
+ @wrap_as_background_process("delete_expired_sessions")
+ async def _delete_expired_sessions(self) -> None:
+ """Remove sessions with expiry dates that have passed."""
+
+ def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
+ sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "delete_expired_sessions",
+ _delete_expired_sessions_txn,
+ self._clock.time_msec(),
+ )
diff --git a/synapse/storage/schema/main/delta/62/02session_store.sql b/synapse/storage/schema/main/delta/62/02session_store.sql
new file mode 100644
index 0000000000..535fb34c10
--- /dev/null
+++ b/synapse/storage/schema/main/delta/62/02session_store.sql
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2021 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS sessions(
+ session_type TEXT NOT NULL, -- The unique key for this type of session.
+ session_id TEXT NOT NULL, -- The session ID passed to the client.
+ value TEXT NOT NULL, -- A JSON dictionary to persist.
+ expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds).
+ UNIQUE (session_type, session_id)
+);
--
cgit 1.5.1
From 7367473f965fed1160cb8633de341c5833e5b662 Mon Sep 17 00:00:00 2001
From: Sean
Date: Wed, 25 Aug 2021 10:51:08 +0100
Subject: Fix error when selecting between thumbnails with the same quality
(#10684)
Fixes #10318
---
changelog.d/10684.bugfix | 1 +
synapse/rest/media/v1/thumbnail_resource.py | 26 ++++++++++++-------
tests/rest/media/v1/test_media_storage.py | 39 ++++++++++++++++++++++++++++-
3 files changed, 56 insertions(+), 10 deletions(-)
create mode 100644 changelog.d/10684.bugfix
(limited to 'synapse')
diff --git a/changelog.d/10684.bugfix b/changelog.d/10684.bugfix
new file mode 100644
index 0000000000..311b17601a
--- /dev/null
+++ b/changelog.d/10684.bugfix
@@ -0,0 +1 @@
+Fix long-standing issue which caused an error when a thumbnail is requested and there are multiple thumbnails with the same quality rating.
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a029d426f0..12bd745cb2 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -15,7 +15,7 @@
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from twisted.web.server import Request
@@ -414,9 +414,9 @@ class ThumbnailResource(DirectServeJsonResource):
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
- crop_info_list = []
+ crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
# Other thumbnails.
- crop_info_list2 = []
+ crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "crop":
@@ -451,15 +451,19 @@ class ThumbnailResource(DirectServeJsonResource):
info,
)
)
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info dictionary and pick the thumbnail that appears earlier
+ # in the list of candidates.
if crop_info_list:
- thumbnail_info = min(crop_info_list)[-1]
+ thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
elif crop_info_list2:
- thumbnail_info = min(crop_info_list2)[-1]
+ thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
- info_list = []
+ info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = []
# Other thumbnails.
- info_list2 = []
+ info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
@@ -477,10 +481,14 @@ class ThumbnailResource(DirectServeJsonResource):
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info dictionary and pick the thumbnail that appears earlier
+ # in the list of candidates.
if info_list:
- thumbnail_info = min(info_list)[-1]
+ thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
elif info_list2:
- thumbnail_info = min(info_list2)[-1]
+ thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
if thumbnail_info:
return FileInfo(
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 6085444b9d..2f7eebfe69 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -21,7 +21,7 @@ from unittest.mock import Mock
from urllib import parse
import attr
-from parameterized import parameterized_class
+from parameterized import parameterized, parameterized_class
from PIL import Image as Image
from twisted.internet import defer
@@ -473,6 +473,43 @@ class MediaRepoTests(unittest.HomeserverTestCase):
},
)
+ @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
+ def test_same_quality(self, method, desired_size):
+ """Test that choosing between thumbnails with the same quality rating succeeds.
+
+ We are not particular about which thumbnail is chosen."""
+ self.assertIsNotNone(
+ self.thumbnail_resource._select_thumbnail(
+ desired_width=desired_size,
+ desired_height=desired_size,
+ desired_method=method,
+ desired_type=self.test_image.content_type,
+ # Provide two identical thumbnails which are guaranteed to have the same
+ # quality rating.
+ thumbnail_infos=[
+ {
+ "thumbnail_width": 32,
+ "thumbnail_height": 32,
+ "thumbnail_method": method,
+ "thumbnail_type": self.test_image.content_type,
+ "thumbnail_length": 256,
+ "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ },
+ {
+ "thumbnail_width": 32,
+ "thumbnail_height": 32,
+ "thumbnail_method": method,
+ "thumbnail_type": self.test_image.content_type,
+ "thumbnail_length": 256,
+ "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ },
+ ],
+ file_id=f"image{self.test_image.extension}",
+ url_cache=None,
+ server_name=None,
+ )
+ )
+
def test_x_robots_tag_header(self):
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
--
cgit 1.5.1
From b45cc1530b1438b8bfd9c09f179c7338e85ac083 Mon Sep 17 00:00:00 2001
From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Date: Wed, 25 Aug 2021 17:00:44 +0100
Subject: Make a note to leave a summary when one is bumping the schema version
(#10621)
I found this easy to miss (and evidently, it looks like it was missed for schema version 62).
---
changelog.d/10621.misc | 1 +
synapse/storage/schema/__init__.py | 2 ++
2 files changed, 3 insertions(+)
create mode 100644 changelog.d/10621.misc
(limited to 'synapse')
diff --git a/changelog.d/10621.misc b/changelog.d/10621.misc
new file mode 100644
index 0000000000..b8de2e1911
--- /dev/null
+++ b/changelog.d/10621.misc
@@ -0,0 +1 @@
+Add a comment asking developers to leave a reason when bumping the database schema version.
\ No newline at end of file
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index a5bc0ee8a5..af9cc69949 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# When updating these values, please leave a short summary of the changes below.
+
SCHEMA_VERSION = 63
"""Represents the expectations made by the codebase about the database schema
--
cgit 1.5.1
From 5548fe097881b543cba37c7cda27ff7efe55025d Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 26 Aug 2021 07:16:53 -0400
Subject: Cache the result of fetching the room hierarchy over federation.
(#10647)
---
changelog.d/10647.misc | 1 +
synapse/federation/federation_client.py | 106 ++++++++++++++++++++------------
2 files changed, 67 insertions(+), 40 deletions(-)
create mode 100644 changelog.d/10647.misc
(limited to 'synapse')
diff --git a/changelog.d/10647.misc b/changelog.d/10647.misc
new file mode 100644
index 0000000000..4407a9030d
--- /dev/null
+++ b/changelog.d/10647.misc
@@ -0,0 +1 @@
+Improve the performance of the `/hierarchy` API (from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)) by caching responses received over federation.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 44d9e8a5c7..1416abd0fb 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -111,6 +111,23 @@ class FederationClient(FederationBase):
reset_expiry_on_get=False,
)
+ # A cache for fetching the room hierarchy over federation.
+ #
+ # Some stale data over federation is OK, but must be refreshed
+ # periodically since the local server is in the room.
+ #
+ # It is a map of (room ID, suggested-only) -> the response of
+ # get_room_hierarchy.
+ self._get_room_hierarchy_cache: ExpiringCache[
+ Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]
+ ] = ExpiringCache(
+ cache_name="get_room_hierarchy_cache",
+ clock=self._clock,
+ max_len=1000,
+ expiry_ms=5 * 60 * 1000,
+ reset_expiry_on_get=False,
+ )
+
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
@@ -1324,6 +1341,10 @@ class FederationClient(FederationBase):
remote servers
"""
+ cached_result = self._get_room_hierarchy_cache.get((room_id, suggested_only))
+ if cached_result:
+ return cached_result
+
async def send_request(
destination: str,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
@@ -1370,58 +1391,63 @@ class FederationClient(FederationBase):
return room, children, inaccessible_children
try:
- return await self._try_destination_list(
+ result = await self._try_destination_list(
"fetch room hierarchy",
destinations,
send_request,
failover_on_unknown_endpoint=True,
)
except SynapseError as e:
+ # If an unexpected error occurred, re-raise it.
+ if e.code != 502:
+ raise
+
# Fallback to the old federation API and translate the results if
# no servers implement the new API.
#
# The algorithm below is a bit inefficient as it only attempts to
- # get information for the requested room, but the legacy API may
+ # parse information for the requested room, but the legacy API may
# return additional layers.
- if e.code == 502:
- legacy_result = await self.get_space_summary(
- destinations,
- room_id,
- suggested_only,
- max_rooms_per_space=None,
- exclude_rooms=[],
- )
+ legacy_result = await self.get_space_summary(
+ destinations,
+ room_id,
+ suggested_only,
+ max_rooms_per_space=None,
+ exclude_rooms=[],
+ )
- # Find the requested room in the response (and remove it).
- for _i, room in enumerate(legacy_result.rooms):
- if room.get("room_id") == room_id:
- break
- else:
- # The requested room was not returned, nothing we can do.
- raise
- requested_room = legacy_result.rooms.pop(_i)
-
- # Find any children events of the requested room.
- children_events = []
- children_room_ids = set()
- for event in legacy_result.events:
- if event.room_id == room_id:
- children_events.append(event.data)
- children_room_ids.add(event.state_key)
- # And add them under the requested room.
- requested_room["children_state"] = children_events
-
- # Find the children rooms.
- children = []
- for room in legacy_result.rooms:
- if room.get("room_id") in children_room_ids:
- children.append(room)
-
- # It isn't clear from the response whether some of the rooms are
- # not accessible.
- return requested_room, children, ()
-
- raise
+ # Find the requested room in the response (and remove it).
+ for _i, room in enumerate(legacy_result.rooms):
+ if room.get("room_id") == room_id:
+ break
+ else:
+ # The requested room was not returned, nothing we can do.
+ raise
+ requested_room = legacy_result.rooms.pop(_i)
+
+ # Find any children events of the requested room.
+ children_events = []
+ children_room_ids = set()
+ for event in legacy_result.events:
+ if event.room_id == room_id:
+ children_events.append(event.data)
+ children_room_ids.add(event.state_key)
+ # And add them under the requested room.
+ requested_room["children_state"] = children_events
+
+ # Find the children rooms.
+ children = []
+ for room in legacy_result.rooms:
+ if room.get("room_id") in children_room_ids:
+ children.append(room)
+
+ # It isn't clear from the response whether some of the rooms are
+ # not accessible.
+ result = (requested_room, children, ())
+
+ # Cache the result to avoid fetching data over federation every time.
+ self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
+ return result
@attr.s(frozen=True, slots=True, auto_attribs=True)
--
cgit 1.5.1
From 1aa0dad02187c3b972187f5952cfbce336b0ca5c Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Thu, 26 Aug 2021 07:53:52 -0400
Subject: Additional type hints for REST servlets (part 2). (#10674)
Applies the changes from #10665 to additional modules.
---
changelog.d/10674.misc | 1 +
synapse/handlers/presence.py | 5 +++
synapse/rest/client/auth.py | 11 ++++---
synapse/rest/client/devices.py | 48 +++++++++++++++-------------
synapse/rest/client/events.py | 38 ++++++++++++++---------
synapse/rest/client/filter.py | 26 +++++++++++-----
synapse/rest/client/groups.py | 3 +-
synapse/rest/client/initial_sync.py | 16 +++++++---
synapse/rest/client/keys.py | 57 +++++++++++++---------------------
synapse/rest/client/knock.py | 3 +-
synapse/rest/client/login.py | 21 ++++++-------
synapse/rest/client/logout.py | 17 +++++++---
synapse/rest/client/notifications.py | 13 ++++++--
synapse/rest/client/openid.py | 16 +++++++---
synapse/rest/client/password_policy.py | 18 ++++++-----
synapse/rest/client/presence.py | 24 +++++++++-----
synapse/rest/client/profile.py | 37 ++++++++++++++++------
17 files changed, 216 insertions(+), 138 deletions(-)
create mode 100644 changelog.d/10674.misc
(limited to 'synapse')
diff --git a/changelog.d/10674.misc b/changelog.d/10674.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10674.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 7ca14e1d84..4418d63df7 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -353,6 +353,11 @@ class BasePresenceHandler(abc.ABC):
# otherwise would not do).
await self.set_state(UserID.from_string(user_id), state, force_notify=True)
+ async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
+ raise NotImplementedError(
+ "Attempting to check presence on a non-presence worker."
+ )
+
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 91800c0278..df8cc4ac7a 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -15,11 +15,14 @@
import logging
from typing import TYPE_CHECKING
+from twisted.web.server import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.http.server import respond_with_html
+from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
from ._base import client_patterns
@@ -49,7 +52,7 @@ class AuthRestServlet(RestServlet):
self.registration_token_template = hs.config.registration_token_template
self.success_template = hs.config.fallback_success_template
- async def on_GET(self, request, stagetype):
+ async def on_GET(self, request: SynapseRequest, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@@ -88,7 +91,7 @@ class AuthRestServlet(RestServlet):
respond_with_html(request, 200, html)
return None
- async def on_POST(self, request, stagetype):
+ async def on_POST(self, request: Request, stagetype: str) -> None:
session = parse_string(request, "session")
if not session:
@@ -172,5 +175,5 @@ class AuthRestServlet(RestServlet):
return None
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AuthRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8b9674db06..25bc3c8f47 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -14,34 +14,36 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api import errors
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
@@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/delete_devices")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -65,7 +67,7 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
@@ -100,18 +102,16 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P[^/]*)$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
- async def on_GET(self, request, device_id):
+ async def on_GET(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
@@ -119,7 +119,9 @@ class DeviceRestServlet(RestServlet):
return 200, device
@interactive_auth_handler
- async def on_DELETE(self, request, device_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
@@ -146,7 +148,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {}
- async def on_PUT(self, request, device_id):
+ async def on_PUT(
+ self, request: SynapseRequest, device_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request)
@@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet):
PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_GET(self, request: SynapseRequest):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
dehydrated_device = await self.device_handler.get_dehydrated_device(
requester.user.to_string()
@@ -211,7 +215,7 @@ class DehydratedDeviceServlet(RestServlet):
else:
raise errors.NotFoundError("No dehydrated device available")
- async def on_PUT(self, request: SynapseRequest):
+ async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_json_object_from_request(request)
requester = await self.auth.get_user_by_req(request)
@@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet):
"/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- async def on_POST(self, request: SynapseRequest):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
submission = parse_json_object_from_request(request)
@@ -292,7 +296,7 @@ class ClaimDehydratedDeviceServlet(RestServlet):
return (200, result)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 52bb579cfd..13b72a045a 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -14,11 +14,18 @@
"""This module contains REST servlets to do with event streaming, /events."""
import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.http.server import HttpServer
+from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
- room_id = None
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
if is_guest:
- if b"room_id" not in request.args:
+ if b"room_id" not in args:
raise SynapseError(400, "Guest users must specify room_id param")
- if b"room_id" in request.args:
- room_id = request.args[b"room_id"][0].decode("ascii")
+ room_id = parse_string(request, "room_id")
pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
- if b"timeout" in request.args:
+ if b"timeout" in args:
try:
- timeout = int(request.args[b"timeout"][0])
+ timeout = int(args[b"timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
- as_client_event = b"raw" not in request.args
+ as_client_event = b"raw" not in args
chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
@@ -70,25 +76,27 @@ class EventStreamRestServlet(RestServlet):
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()
- async def on_GET(self, request, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, event_id: str
+ ) -> Tuple[int, Union[str, JsonDict]]:
requester = await self.auth.get_user_by_req(request)
event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec()
if event:
- event = await self._event_serializer.serialize_event(event, time_now)
- return 200, event
+ result = await self._event_serializer.serialize_event(event, time_now)
+ return 200, result
else:
return 404, "Event not found."
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EventStreamRestServlet(hs).register(http_server)
EventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 411667a9c8..6ed60c7418 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -13,26 +13,34 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID
from ._base import client_patterns, set_timeline_upper_limit
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- async def on_GET(self, request, user_id, filter_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, filter_id: str
+ ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@@ -43,13 +51,13 @@ class GetFilterRestServlet(RestServlet):
raise AuthError(403, "Can only get filters for local users")
try:
- filter_id = int(filter_id)
+ filter_id_int = int(filter_id)
except Exception:
raise SynapseError(400, "Invalid filter_id")
try:
filter_collection = await self.filtering.get_user_filter(
- user_localpart=target_user.localpart, filter_id=filter_id
+ user_localpart=target_user.localpart, filter_id=filter_id_int
)
except StoreError as e:
if e.code != 404:
@@ -62,13 +70,15 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P[^/]*)/filter")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@@ -89,6 +99,6 @@ class CreateFilterRestServlet(RestServlet):
return 200, {"filter_id": str(filter_id)}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index 6285680c00..c3667ff8aa 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -26,6 +26,7 @@ from synapse.api.constants import (
)
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -930,7 +931,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py
index 12ba0e91db..49b1037b28 100644
--- a/synapse/rest/client/initial_sync.py
+++ b/synapse/rest/client/initial_sync.py
@@ -12,25 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Dict, List, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
# TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- as_client_event = b"raw" not in request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
+ as_client_event = b"raw" not in args
pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
@@ -43,5 +51,5 @@ class InitialSyncRestServlet(RestServlet):
return 200, content
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
InitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 012491f597..7281b2ee29 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,20 +15,25 @@
# limitations under the License.
import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.types import StreamToken
+from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -60,18 +65,16 @@ class KeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys")
- async def on_POST(self, request, device_id):
+ async def on_POST(
+ self, request: SynapseRequest, device_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -149,16 +152,12 @@ class KeyQueryServlet(RestServlet):
PATTERNS = client_patterns("/keys/query$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
device_id = requester.device_id
@@ -195,17 +194,13 @@ class KeyChangesServlet(RestServlet):
PATTERNS = client_patterns("/keys/changes$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from", required=True)
@@ -245,12 +240,12 @@ class OneTimeKeyServlet(RestServlet):
PATTERNS = client_patterns("/keys/claim$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
@@ -269,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -281,7 +272,7 @@ class SigningKeyUploadServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -329,16 +320,12 @@ class SignaturesUploadServlet(RestServlet):
PATTERNS = client_patterns("/keys/signatures/upload$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -349,7 +336,7 @@ class SignaturesUploadServlet(RestServlet):
return 200, result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index 7d1bc40658..68fb08d0ba 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -19,6 +19,7 @@ from twisted.web.server import Request
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -103,5 +104,5 @@ class KnockRoomAliasServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KnockRoomAliasServlet(hs).register(http_server)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 11d07776b2..4be502a77b 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
@@ -110,7 +110,7 @@ class LoginRestServlet(RestServlet):
# counters are initialised for the auth_provider_ids.
_load_sso_handlers(hs)
- def on_GET(self, request: SynapseRequest):
+ def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -157,7 +157,7 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
- async def on_POST(self, request: SynapseRequest):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
login_submission = parse_json_object_from_request(request)
if self._msc2918_enabled:
@@ -217,7 +217,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
- ):
+ ) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -467,10 +467,7 @@ class RefreshTokenServlet(RestServlet):
self._clock = hs.get_clock()
self.access_token_lifetime = hs.config.access_token_lifetime
- async def on_POST(
- self,
- request: SynapseRequest,
- ):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
refresh_submission = parse_json_object_from_request(request)
assert_params_in_dict(refresh_submission, ["refresh_token"])
@@ -570,7 +567,7 @@ class SsoRedirectServlet(RestServlet):
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._cas_handler = hs.get_cas_handler()
@@ -592,7 +589,7 @@ class CasTicketServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
if hs.config.access_token_lifetime is not None:
RefreshTokenServlet(hs).register(http_server)
@@ -601,7 +598,7 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
-def _load_sso_handlers(hs: "HomeServer"):
+def _load_sso_handlers(hs: "HomeServer") -> None:
"""Ensure that the SSO handlers are loaded, if they are enabled by configuration.
This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 6055cac2bd..193a6951b9 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -13,9 +13,16 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -23,13 +30,13 @@ logger = logging.getLogger(__name__)
class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None:
@@ -48,13 +55,13 @@ class LogoutRestServlet(RestServlet):
class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
@@ -67,6 +74,6 @@ class LogoutAllRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LogoutRestServlet(hs).register(http_server)
LogoutAllRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 0ede643c2d..d1d8a984c6 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -13,26 +13,33 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.events.utils import format_event_for_client_v2_without_room_id
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_patterns("/notifications$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -87,5 +94,5 @@ class NotificationsServlet(RestServlet):
return 200, {"notifications": returned_push_actions, "next_token": next_token}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
NotificationsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py
index e8d2673819..4dda6dce4b 100644
--- a/synapse/rest/client/openid.py
+++ b/synapse/rest/client/openid.py
@@ -12,15 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet):
EXPIRES_MS = 3600 * 1000
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
@@ -90,5 +98,5 @@ class IdTokenServlet(RestServlet):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
IdTokenServlet(hs).register(http_server)
diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py
index a83927aee6..6d64efb165 100644
--- a/synapse/rest/client/password_policy.py
+++ b/synapse/rest/client/password_policy.py
@@ -13,28 +13,32 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+from twisted.web.server import Request
+
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class PasswordPolicyServlet(RestServlet):
PATTERNS = client_patterns("/password_policy$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
- def on_GET(self, request):
+ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.enabled or not self.policy:
return (200, {})
@@ -53,5 +57,5 @@ class PasswordPolicyServlet(RestServlet):
return (200, policy)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PasswordPolicyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py
index 6c27e5faf9..94dd4fe2f4 100644
--- a/synapse/rest/client/presence.py
+++ b/synapse/rest/client/presence.py
@@ -15,12 +15,18 @@
""" This module contains REST servlets to do with presence: /presence/
"""
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -28,7 +34,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler()
@@ -37,7 +43,9 @@ class PresenceStatusRestServlet(RestServlet):
self._use_presence = hs.config.server.use_presence
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
@@ -53,13 +61,15 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(
+ result = format_user_presence_state(
state, self.clock.time_msec(), include_user_id=False
)
- return 200, state
+ return 200, result
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
@@ -91,5 +101,5 @@ class PresenceStatusRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PresenceStatusRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index 5463ed2c4f..d0f20de569 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -14,22 +14,31 @@
""" This module contains REST servlets to do with profile: /profile/ """
+from typing import TYPE_CHECKING, Tuple
+
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -48,7 +57,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, ret
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
@@ -72,13 +83,15 @@ class ProfileDisplaynameRestServlet(RestServlet):
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -97,7 +110,9 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, ret
- async def on_PUT(self, request, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
@@ -120,13 +135,15 @@ class ProfileAvatarURLRestServlet(RestServlet):
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
@@ -149,7 +166,7 @@ class ProfileRestServlet(RestServlet):
return 200, ret
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)
--
cgit 1.5.1
From ad17fbd20eb2dd9fb10a3d02ab1b69e9a0d5b50c Mon Sep 17 00:00:00 2001
From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com>
Date: Thu, 26 Aug 2021 13:53:57 +0100
Subject: Remove pushers when deleting 3pid from account (#10581)
When a user deletes an email from their account it will
now also remove all pushers for that email and that user
(even if these pushers were created by a different client)
---
CHANGES.md | 2 +
changelog.d/10581.bugfix | 1 +
docs/upgrade.md | 5 ++
synapse/handlers/auth.py | 5 +-
synapse/storage/databases/main/pusher.py | 72 ++++++++++++++++++++++
.../delta/63/02delete_unlinked_email_pushers.sql | 20 ++++++
tests/push/test_email.py | 39 ++++++++++++
7 files changed, 143 insertions(+), 1 deletion(-)
create mode 100644 changelog.d/10581.bugfix
create mode 100644 synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
(limited to 'synapse')
diff --git a/CHANGES.md b/CHANGES.md
index f8da8771aa..24f3d53a6d 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,5 @@
+Users will stop receiving message updates via email for addresses that were previously linked to their account
+
Synapse 1.41.0 (2021-08-24)
===========================
diff --git a/changelog.d/10581.bugfix b/changelog.d/10581.bugfix
new file mode 100644
index 0000000000..15c7da4497
--- /dev/null
+++ b/changelog.d/10581.bugfix
@@ -0,0 +1 @@
+Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted.
\ No newline at end of file
diff --git a/docs/upgrade.md b/docs/upgrade.md
index 6d4b8cb48e..dcf0a7db5b 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -107,6 +107,11 @@ This may affect you if you make use of custom HTML templates for the
The template is now provided an `error` variable if the authentication
process failed. See the default templates linked above for an example.
+# Upgrading to v1.42.0
+
+## Removal of out-of-date email pushers
+Users will stop receiving message updates via email for addresses that were
+once, but not still, linked to their account.
# Upgrading to v1.41.0
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 98d3d2d97f..34725324a6 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1464,6 +1464,10 @@ class AuthHandler(BaseHandler):
)
await self.store.user_delete_threepid(user_id, medium, address)
+ if medium == "email":
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id="m.email", pushkey=address, user_id=user_id
+ )
return result
async def hash(self, password: str) -> str:
@@ -1732,7 +1736,6 @@ class AuthHandler(BaseHandler):
@attr.s(slots=True)
class MacaroonGenerator:
-
hs = attr.ib()
def generate_guest_access_token(self, user_id: str) -> str:
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b48fe086d4..e47caa2125 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_stale_pushers,
)
+ self.db_pool.updates.register_background_update_handler(
+ "remove_deleted_email_pushers",
+ self._remove_deleted_email_pushers,
+ )
+
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
@@ -388,6 +393,73 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted
+ async def _remove_deleted_email_pushers(
+ self, progress: dict, batch_size: int
+ ) -> int:
+ """A background update that deletes all pushers for deleted email addresses.
+
+ In previous versions of synapse, when users deleted their email address, it didn't
+ also delete all the pushers for that email address. This background update removes
+ those to prevent unwanted emails. This should only need to be run once (when users
+ upgrade to v1.42.0
+
+ Args:
+ progress: dict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ last_pusher = progress.get("last_pusher", 0)
+
+ def _delete_pushers(txn) -> int:
+
+ sql = """
+ SELECT p.id, p.user_name, p.app_id, p.pushkey
+ FROM pushers AS p
+ LEFT JOIN user_threepids AS t
+ ON t.user_id = p.user_name
+ AND t.medium = 'email'
+ AND t.address = p.pushkey
+ WHERE t.user_id is NULL
+ AND p.app_id = 'm.email'
+ AND p.id > ?
+ ORDER BY p.id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_pusher, batch_size))
+
+ last = None
+ num_deleted = 0
+ for row in txn:
+ last = row[0]
+ num_deleted += 1
+ self.db_pool.simple_delete_txn(
+ txn,
+ "pushers",
+ {"user_name": row[1], "app_id": row[2], "pushkey": row[3]},
+ )
+
+ if last is not None:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "remove_deleted_email_pushers", {"last_pusher": last}
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deleted_email_pushers", _delete_pushers
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "remove_deleted_email_pushers"
+ )
+
+ return number_deleted
+
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int:
diff --git a/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
new file mode 100644
index 0000000000..611c4b95cf
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We may not have deleted all pushers for emails that are no longer linked
+-- to an account, so we set up a background job to delete them.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6302, 'remove_deleted_email_pushers', '{}');
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index e0a3342088..eea07485a0 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -125,6 +125,8 @@ class EmailPusherTests(HomeserverTestCase):
)
)
+ self.auth_handler = hs.get_auth_handler()
+
def test_need_validated_email(self):
"""Test that we can only add an email pusher if the user has validated
their email.
@@ -305,6 +307,43 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about that message
self._check_for_mail()
+ def test_no_email_sent_after_removed(self):
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room,
+ src=self.user_id,
+ tok=self.access_token,
+ targ=self.others[0].id,
+ )
+ self.helper.join(
+ room=room,
+ user=self.others[0].id,
+ tok=self.others[0].token,
+ )
+
+ # The other user sends a single message.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+ # We should get emailed about that message
+ self._check_for_mail()
+
+ # disassociate the user's email address
+ self.get_success(
+ self.auth_handler.delete_threepid(
+ user_id=self.user_id,
+ medium="email",
+ address="a@example.com",
+ )
+ )
+
+ # check that the pusher for that email address has been deleted
+ pushers = self.get_success(
+ self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+ )
+ pushers = list(pushers)
+ self.assertEqual(len(pushers), 0)
+
def _check_for_mail(self):
"""Check that the user receives an email notification"""
--
cgit 1.5.1
From 40f619eaa54d2391deccec473fc0f655c379e766 Mon Sep 17 00:00:00 2001
From: Aaron Raimist
Date: Thu, 26 Aug 2021 11:07:58 -0500
Subject: Validate new m.room.power_levels events (#10232)
Signed-off-by: Aaron Raimist
---
changelog.d/10232.bugfix | 1 +
synapse/events/utils.py | 5 ++-
synapse/events/validator.py | 77 ++++++++++++++++++++++++++++++++-
synapse/python_dependencies.py | 3 +-
tests/rest/client/test_power_levels.py | 78 ++++++++++++++++++++++++++++++++++
5 files changed, 160 insertions(+), 4 deletions(-)
create mode 100644 changelog.d/10232.bugfix
(limited to 'synapse')
diff --git a/changelog.d/10232.bugfix b/changelog.d/10232.bugfix
new file mode 100644
index 0000000000..7be72271e0
--- /dev/null
+++ b/changelog.d/10232.bugfix
@@ -0,0 +1 @@
+Validate new `m.room.power_levels` events. Contributed by @aaronraimist.
\ No newline at end of file
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index b6da2f60af..738a151cef 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -32,6 +32,9 @@ from . import EventBase
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r"(? EventBase:
"""Returns a pruned version of the given event, which removes all keys we
@@ -505,7 +508,7 @@ def validate_canonicaljson(value: Any):
* NaN, Infinity, -Infinity
"""
if isinstance(value, int):
- if value <= -(2 ** 53) or 2 ** 53 <= value:
+ if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value:
raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON)
elif isinstance(value, float):
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index fa6987d7cb..33954b4f62 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -11,16 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import collections.abc
from typing import Union
+import jsonschema
+
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
-from synapse.events.utils import validate_canonicaljson
+from synapse.events.utils import (
+ CANONICALJSON_MAX_INT,
+ CANONICALJSON_MIN_INT,
+ validate_canonicaljson,
+)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.types import EventID, RoomID, UserID
@@ -87,6 +93,29 @@ class EventValidator:
400, "Can't create an ACL event that denies the local server"
)
+ if event.type == EventTypes.PowerLevels:
+ try:
+ jsonschema.validate(
+ instance=event.content,
+ schema=POWER_LEVELS_SCHEMA,
+ cls=plValidator,
+ )
+ except jsonschema.ValidationError as e:
+ if e.path:
+ # example: "users_default": '0' is not of type 'integer'
+ message = '"' + e.path[-1] + '": ' + e.message # noqa: B306
+ # jsonschema.ValidationError.message is a valid attribute
+ else:
+ # example: '0' is not of type 'integer'
+ message = e.message # noqa: B306
+ # jsonschema.ValidationError.message is a valid attribute
+
+ raise SynapseError(
+ code=400,
+ msg=message,
+ errcode=Codes.BAD_JSON,
+ )
+
def _validate_retention(self, event: EventBase):
"""Checks that an event that defines the retention policy for a room respects the
format enforced by the spec.
@@ -185,3 +214,47 @@ class EventValidator:
def _ensure_state_event(self, event):
if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,))
+
+
+POWER_LEVELS_SCHEMA = {
+ "type": "object",
+ "properties": {
+ "ban": {"$ref": "#/definitions/int"},
+ "events": {"$ref": "#/definitions/objectOfInts"},
+ "events_default": {"$ref": "#/definitions/int"},
+ "invite": {"$ref": "#/definitions/int"},
+ "kick": {"$ref": "#/definitions/int"},
+ "notifications": {"$ref": "#/definitions/objectOfInts"},
+ "redact": {"$ref": "#/definitions/int"},
+ "state_default": {"$ref": "#/definitions/int"},
+ "users": {"$ref": "#/definitions/objectOfInts"},
+ "users_default": {"$ref": "#/definitions/int"},
+ },
+ "definitions": {
+ "int": {
+ "type": "integer",
+ "minimum": CANONICALJSON_MIN_INT,
+ "maximum": CANONICALJSON_MAX_INT,
+ },
+ "objectOfInts": {
+ "type": "object",
+ "additionalProperties": {"$ref": "#/definitions/int"},
+ },
+ },
+}
+
+
+def _create_power_level_validator():
+ validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
+
+ # by default jsonschema does not consider a frozendict to be an object so
+ # we need to use a custom type checker
+ # https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types
+ type_checker = validator.TYPE_CHECKER.redefine(
+ "object", lambda checker, thing: isinstance(thing, collections.abc.Mapping)
+ )
+
+ return jsonschema.validators.extend(validator, type_checker=type_checker)
+
+
+plValidator = _create_power_level_validator()
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index cdcbdd772b..154e5b7028 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -48,7 +48,8 @@ logger = logging.getLogger(__name__)
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
REQUIREMENTS = [
- "jsonschema>=2.5.1",
+ # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0
+ "jsonschema>=3.0.0",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.4.0",
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 91d0762cb0..c0de4c93a8 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.errors import Codes
+from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT
from synapse.rest import admin
from synapse.rest.client import login, room, sync
@@ -203,3 +205,79 @@ class PowerLevelsTestCase(HomeserverTestCase):
tok=self.admin_access_token,
expect_code=200, # expect success
)
+
+ def test_cannot_set_string_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL "0"
+ room_power_levels["users"].update({self.user_user_id: "0"})
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_cannot_set_unsafe_large_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL above the max safe integer
+ room_power_levels["users"].update(
+ {self.user_user_id: CANONICALJSON_MAX_INT + 1}
+ )
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_cannot_set_unsafe_small_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL below the minimum safe integer
+ room_power_levels["users"].update(
+ {self.user_user_id: CANONICALJSON_MIN_INT - 1}
+ )
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
--
cgit 1.5.1
From 96715d763362a7027c39d571cfde3aa8b7b82fcf Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Thu, 26 Aug 2021 18:34:57 +0100
Subject: Make `backfill` and `get_missing_events` use the same codepath
(#10645)
Given that backfill and get_missing_events are basically the same thing, it's somewhat crazy that we have entirely separate code paths for them. This makes backfill use the existing get_missing_events code, and then clears up all the unused code.
---
changelog.d/10645.misc | 1 +
synapse/handlers/federation.py | 273 ++++---------------------
synapse/storage/databases/main/purge_events.py | 1 +
3 files changed, 42 insertions(+), 233 deletions(-)
create mode 100644 changelog.d/10645.misc
(limited to 'synapse')
diff --git a/changelog.d/10645.misc b/changelog.d/10645.misc
new file mode 100644
index 0000000000..ac19263cd8
--- /dev/null
+++ b/changelog.d/10645.misc
@@ -0,0 +1 @@
+Make `backfill` and `get_missing_events` use the same codepath.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 246df43501..6fa2fc8f52 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -65,6 +65,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
+from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
@@ -116,10 +117,6 @@ class _NewEventInfo:
Attributes:
event: the received event
- state: the state at that event, according to /state_ids from a remote
- homeserver. Only populated for backfilled events which are going to be a
- new backwards extremity.
-
claimed_auth_event_map: a map of (type, state_key) => event for the event's
claimed auth_events.
@@ -134,7 +131,6 @@ class _NewEventInfo:
"""
event: EventBase
- state: Optional[Sequence[EventBase]]
claimed_auth_event_map: StateMap[EventBase]
@@ -443,113 +439,7 @@ class FederationHandler(BaseHandler):
return
logger.info("Got %d prev_events", len(missing_events))
- await self._process_pulled_events(origin, missing_events)
-
- async def _get_state_for_room(
- self,
- destination: str,
- room_id: str,
- event_id: str,
- ) -> List[EventBase]:
- """Requests all of the room state at a given event from a remote
- homeserver.
-
- Will also fetch any missing events reported in the `auth_chain_ids`
- section of `/state_ids`.
-
- Args:
- destination: The remote homeserver to query for the state.
- room_id: The id of the room we're interested in.
- event_id: The id of the event we want the state at.
-
- Returns:
- A list of events in the state, not including the event itself.
- """
- (
- state_event_ids,
- auth_event_ids,
- ) = await self.federation_client.get_room_state_ids(
- destination, room_id, event_id=event_id
- )
-
- # Fetch the state events from the DB, and check we have the auth events.
- event_map = await self.store.get_events(state_event_ids, allow_rejected=True)
- auth_events_in_store = await self.store.have_seen_events(
- room_id, auth_event_ids
- )
-
- # Check for missing events. We handle state and auth event seperately,
- # as we want to pull the state from the DB, but we don't for the auth
- # events. (Note: we likely won't use the majority of the auth chain, and
- # it can be *huge* for large rooms, so it's worth ensuring that we don't
- # unnecessarily pull it from the DB).
- missing_state_events = set(state_event_ids) - set(event_map)
- missing_auth_events = set(auth_event_ids) - set(auth_events_in_store)
- if missing_state_events or missing_auth_events:
- await self._get_events_and_persist(
- destination=destination,
- room_id=room_id,
- events=missing_state_events | missing_auth_events,
- )
-
- if missing_state_events:
- new_events = await self.store.get_events(
- missing_state_events, allow_rejected=True
- )
- event_map.update(new_events)
-
- missing_state_events.difference_update(new_events)
-
- if missing_state_events:
- logger.warning(
- "Failed to fetch missing state events for %s %s",
- event_id,
- missing_state_events,
- )
-
- if missing_auth_events:
- auth_events_in_store = await self.store.have_seen_events(
- room_id, missing_auth_events
- )
- missing_auth_events.difference_update(auth_events_in_store)
-
- if missing_auth_events:
- logger.warning(
- "Failed to fetch missing auth events for %s %s",
- event_id,
- missing_auth_events,
- )
-
- remote_state = list(event_map.values())
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
-
- bad_events = [
- (event.event_id, event.room_id)
- for event in remote_state
- if event.room_id != room_id
- ]
-
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned auth/state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
-
- if bad_events:
- remote_state = [e for e in remote_state if e.room_id == room_id]
-
- return remote_state
+ await self._process_pulled_events(origin, missing_events, backfilled=False)
async def _get_state_after_missing_prev_event(
self,
@@ -567,10 +457,6 @@ class FederationHandler(BaseHandler):
Returns:
A list of events in the state, including the event itself
"""
- # TODO: This function is basically the same as _get_state_for_room. Can
- # we make backfill() use it, rather than having two code paths? I think the
- # only difference is that backfill() persists the prev events separately.
-
(
state_event_ids,
auth_event_ids,
@@ -681,6 +567,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
+ backfilled: bool = False,
) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@@ -693,6 +580,9 @@ class FederationHandler(BaseHandler):
state: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
+
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
"""
logger.debug("Processing event: %s", event)
@@ -700,10 +590,15 @@ class FederationHandler(BaseHandler):
context = await self.state_handler.compute_event_context(
event, old_state=state
)
- await self._auth_and_persist_event(origin, event, context, state=state)
+ await self._auth_and_persist_event(
+ origin, event, context, state=state, backfilled=backfilled
+ )
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
+ if backfilled:
+ return
+
# For encrypted messages we check that we know about the sending device,
# if we don't then we mark the device cache for that user as stale.
if event.type == EventTypes.Encrypted:
@@ -868,7 +763,7 @@ class FederationHandler(BaseHandler):
@log_function
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: List[str]
- ) -> List[EventBase]:
+ ) -> None:
"""Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -878,6 +773,9 @@ class FederationHandler(BaseHandler):
sanity-checking on them. If any of the backfilled events are invalid,
this method throws a SynapseError.
+ We might also raise an InvalidResponseError if the response from the remote
+ server is just bogus.
+
TODO: make this more useful to distinguish failures of the remote
server from invalid events (there is probably no point in trying to
re-fetch invalid events from every other HS in the room.)
@@ -890,111 +788,18 @@ class FederationHandler(BaseHandler):
)
if not events:
- return []
-
- # ideally we'd sanity check the events here for excess prev_events etc,
- # but it's hard to reject events at this point without completely
- # breaking backfill in the same way that it is currently broken by
- # events whose signature we cannot verify (#3121).
- #
- # So for now we accept the events anyway. #3124 tracks this.
- #
- # for ev in events:
- # self._sanity_check_event(ev)
-
- # Don't bother processing events we already have.
- seen_events = await self.store.have_events_in_timeline(
- {e.event_id for e in events}
- )
-
- events = [e for e in events if e.event_id not in seen_events]
-
- if not events:
- return []
-
- event_map = {e.event_id: e for e in events}
-
- event_ids = {e.event_id for e in events}
-
- # build a list of events whose prev_events weren't in the batch.
- # (XXX: this will include events whose prev_events we already have; that doesn't
- # sound right?)
- edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
-
- logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
-
- # For each edge get the current state.
-
- state_events = {}
- events_to_state = {}
- for e_id in edges:
- state = await self._get_state_for_room(
- destination=dest,
- room_id=room_id,
- event_id=e_id,
- )
- state_events.update({s.event_id: s for s in state})
- events_to_state[e_id] = state
+ return
- required_auth = {
- a_id
- for event in events + list(state_events.values())
- for a_id in event.auth_event_ids()
- }
- auth_events = await self.store.get_events(required_auth, allow_rejected=True)
- auth_events.update(
- {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
- )
-
- ev_infos = []
-
- # Step 1: persist the events in the chunk we fetched state for (i.e.
- # the backwards extremities), with custom auth events and state
- for e_id in events_to_state:
- # For paranoia we ensure that these events are marked as
- # non-outliers
- ev = event_map[e_id]
- assert not ev.internal_metadata.is_outlier()
-
- ev_infos.append(
- _NewEventInfo(
- event=ev,
- state=events_to_state[e_id],
- claimed_auth_event_map={
- (
- auth_events[a_id].type,
- auth_events[a_id].state_key,
- ): auth_events[a_id]
- for a_id in ev.auth_event_ids()
- if a_id in auth_events
- },
+ # if there are any events in the wrong room, the remote server is buggy and
+ # should not be trusted.
+ for ev in events:
+ if ev.room_id != room_id:
+ raise InvalidResponseError(
+ f"Remote server {dest} returned event {ev.event_id} which is in "
+ f"room {ev.room_id}, when we were backfilling in {room_id}"
)
- )
-
- if ev_infos:
- await self._auth_and_persist_events(
- dest, room_id, ev_infos, backfilled=True
- )
-
- # Step 2: Persist the rest of the events in the chunk one by one
- events.sort(key=lambda e: e.depth)
-
- for event in events:
- if event in events_to_state:
- continue
-
- # For paranoia we ensure that these events are marked as
- # non-outliers
- assert not event.internal_metadata.is_outlier()
-
- context = await self.state_handler.compute_event_context(event)
-
- # We store these one at a time since each event depends on the
- # previous to work out the state.
- # TODO: We can probably do something more clever here.
- await self._auth_and_persist_event(dest, event, context, backfilled=True)
- return events
+ await self._process_pulled_events(dest, events, backfilled=True)
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
@@ -1197,7 +1002,7 @@ class FederationHandler(BaseHandler):
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
return True
- except SynapseError as e:
+ except (SynapseError, InvalidResponseError) as e:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
except HttpResponseException as e:
@@ -1351,7 +1156,7 @@ class FederationHandler(BaseHandler):
else:
logger.info("Missing auth event %s", auth_event_id)
- event_infos.append(_NewEventInfo(event, None, auth))
+ event_infos.append(_NewEventInfo(event, auth))
if event_infos:
await self._auth_and_persist_events(
@@ -1361,7 +1166,7 @@ class FederationHandler(BaseHandler):
)
async def _process_pulled_events(
- self, origin: str, events: Iterable[EventBase]
+ self, origin: str, events: Iterable[EventBase], backfilled: bool
) -> None:
"""Process a batch of events we have pulled from a remote server
@@ -1373,6 +1178,8 @@ class FederationHandler(BaseHandler):
Params:
origin: The server we received these events from
events: The received events.
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
"""
# We want to sort these by depth so we process them and
@@ -1381,9 +1188,11 @@ class FederationHandler(BaseHandler):
for ev in sorted_events:
with nested_logging_context(ev.event_id):
- await self._process_pulled_event(origin, ev)
+ await self._process_pulled_event(origin, ev, backfilled=backfilled)
- async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
+ async def _process_pulled_event(
+ self, origin: str, event: EventBase, backfilled: bool
+ ) -> None:
"""Process a single event that we have pulled from a remote server
Pulls in any events required to auth the event, persists the received event,
@@ -1400,6 +1209,8 @@ class FederationHandler(BaseHandler):
Params:
origin: The server we received this event from
events: The received event
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
"""
logger.info("Processing pulled event %s", event)
@@ -1428,7 +1239,9 @@ class FederationHandler(BaseHandler):
try:
state = await self._resolve_state_at_missing_prevs(origin, event)
- await self._process_received_pdu(origin, event, state=state)
+ await self._process_received_pdu(
+ origin, event, state=state, backfilled=backfilled
+ )
except FederationError as e:
if e.code == 403:
logger.warning("Pulled event %s failed history check.", event_id)
@@ -2451,7 +2264,6 @@ class FederationHandler(BaseHandler):
origin: str,
room_id: str,
event_infos: Collection[_NewEventInfo],
- backfilled: bool = False,
) -> None:
"""Creates the appropriate contexts and persists events. The events
should not depend on one another, e.g. this should be used to persist
@@ -2467,16 +2279,12 @@ class FederationHandler(BaseHandler):
async def prep(ev_info: _NewEventInfo):
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
- res = await self.state_handler.compute_event_context(
- event, old_state=ev_info.state
- )
+ res = await self.state_handler.compute_event_context(event)
res = await self._check_event_auth(
origin,
event,
res,
- state=ev_info.state,
claimed_auth_event_map=ev_info.claimed_auth_event_map,
- backfilled=backfilled,
)
return res
@@ -2493,7 +2301,6 @@ class FederationHandler(BaseHandler):
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
],
- backfilled=backfilled,
)
async def _persist_auth_tree(
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 664c65dac5..bccff5e5b9 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
+ self._invalidate_get_event_cache(event_id)
logger.info("[purge] done")
--
cgit 1.5.1
From 1800aabfc226938036479d2ab1a750aa34cf3974 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Thu, 26 Aug 2021 21:41:44 +0100
Subject: Split `FederationHandler` in half (#10692)
The idea here is to take anything to do with incoming events and move it out to a separate handler, as a way of making FederationHandler smaller.
---
changelog.d/10692.misc | 1 +
synapse/federation/federation_server.py | 7 +-
synapse/handlers/federation.py | 1789 +-------------------
synapse/handlers/federation_event.py | 1825 +++++++++++++++++++++
synapse/replication/http/federation.py | 4 +-
synapse/server.py | 5 +
tests/federation/transport/test_knocking.py | 2 +-
tests/handlers/test_federation.py | 16 +-
tests/handlers/test_presence.py | 4 +-
tests/replication/test_federation_sender_shard.py | 2 +-
tests/test_federation.py | 10 +-
11 files changed, 1884 insertions(+), 1781 deletions(-)
create mode 100644 changelog.d/10692.misc
create mode 100644 synapse/handlers/federation_event.py
(limited to 'synapse')
diff --git a/changelog.d/10692.misc b/changelog.d/10692.misc
new file mode 100644
index 0000000000..a1b0def76b
--- /dev/null
+++ b/changelog.d/10692.misc
@@ -0,0 +1 @@
+Split the event-processing methods in `FederationHandler` into a separate `FederationEventHandler`.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e1b58d40c5..214ee948fa 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -110,6 +110,7 @@ class FederationServer(FederationBase):
super().__init__(hs)
self.handler = hs.get_federation_handler()
+ self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
@@ -787,7 +788,9 @@ class FederationServer(FederationBase):
event = await self._check_sigs_and_hash(room_version, event)
- return await self.handler.on_send_membership_event(origin, event)
+ return await self._federation_event_handler.on_send_membership_event(
+ origin, event
+ )
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
@@ -1005,7 +1008,7 @@ class FederationServer(FederationBase):
async with lock:
logger.info("handling received PDU: %s", event)
try:
- await self.handler.on_receive_pdu(origin, event)
+ await self._federation_event_handler.on_receive_pdu(origin, event)
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6fa2fc8f52..daf1d3bfb3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,23 +17,9 @@
import itertools
import logging
-from collections.abc import Container
from http import HTTPStatus
-from typing import (
- TYPE_CHECKING,
- Collection,
- Dict,
- Iterable,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Union,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
-import attr
-from prometheus_client import Counter
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -41,19 +27,12 @@ from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse import event_auth
-from synapse.api.constants import (
- EventContentFields,
- EventTypes,
- Membership,
- RejectedReason,
- RoomEncryptionAlgorithms,
-)
+from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.api.errors import (
AuthError,
CodeMessageException,
Codes,
FederationDeniedError,
- FederationError,
HttpResponseException,
NotFoundError,
RequestSendFailed,
@@ -61,7 +40,6 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
-from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
@@ -75,28 +53,14 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.logging.utils import log_function
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
- ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
-from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import (
- JsonDict,
- MutableStateMap,
- PersistedEventPosition,
- RoomStreamToken,
- StateMap,
- UserID,
- get_domain_from_id,
-)
-from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.iterutils import batch_iter
+from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
-from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
if TYPE_CHECKING:
@@ -104,45 +68,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-soft_failed_event_counter = Counter(
- "synapse_federation_soft_failed_events_total",
- "Events received over federation that we marked as soft_failed",
-)
-
-
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _NewEventInfo:
- """Holds information about a received event, ready for passing to _auth_and_persist_events
-
- Attributes:
- event: the received event
-
- claimed_auth_event_map: a map of (type, state_key) => event for the event's
- claimed auth_events.
-
- This can include events which have not yet been persisted, in the case that
- we are backfilling a batch of events.
-
- Note: May be incomplete: if we were unable to find all of the claimed auth
- events. Also, treat the contents with caution: the events might also have
- been rejected, might not yet have been authorized themselves, or they might
- be in the wrong room.
-
- """
-
- event: EventBase
- claimed_auth_event_map: StateMap[EventBase]
-
class FederationHandler(BaseHandler):
- """Handles events that originated from federation.
- Responsible for:
- a) handling received Pdus before handing them on as Events to the rest
- of the homeserver (including auth and state conflict resolutions)
- b) converting events that were produced by local clients that may need
- to be sent to remote homeservers.
- c) doing the necessary dances to invite remote users and join remote
- rooms.
+ """Handles general incoming federation requests
+
+ Incoming events are *not* handled here, for which see FederationEventHandler.
"""
def __init__(self, hs: "HomeServer"):
@@ -155,652 +85,35 @@ class FederationHandler(BaseHandler):
self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
- self._state_resolution_handler = hs.get_state_resolution_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
- self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._event_auth_handler = hs.get_event_auth_handler()
- self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_proxied_blacklisted_http_client()
- self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
+ self._federation_event_handler = hs.get_federation_event_handler()
- self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
)
if hs.config.worker_app:
- self._user_device_resync = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
self._maybe_store_room_on_outlier_membership = (
ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
)
else:
- self._device_list_updater = hs.get_device_handler().device_list_updater
self._maybe_store_room_on_outlier_membership = (
self.store.maybe_store_room_on_outlier_membership
)
- # When joining a room we need to queue any events for that room up.
- # For each room, a list of (pdu, origin) tuples.
- self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
- self._room_pdu_linearizer = Linearizer("fed_room_pdu")
-
self._room_backfill = Linearizer("room_backfill")
self.third_party_event_rules = hs.get_third_party_event_rules()
- self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
-
- async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
- """Process a PDU received via a federation /send/ transaction
-
- Args:
- origin: server which initiated the /send/ transaction. Will
- be used to fetch missing events or state.
- pdu: received PDU
- """
-
- room_id = pdu.room_id
- event_id = pdu.event_id
-
- # We reprocess pdus when we have seen them only as outliers
- existing = await self.store.get_event(
- event_id, allow_none=True, allow_rejected=True
- )
-
- # FIXME: Currently we fetch an event again when we already have it
- # if it has been marked as an outlier.
- if existing:
- if not existing.internal_metadata.is_outlier():
- logger.info(
- "Ignoring received event %s which we have already seen", event_id
- )
- return
- if pdu.internal_metadata.is_outlier():
- logger.info(
- "Ignoring received outlier %s which we already have as an outlier",
- event_id,
- )
- return
- logger.info("De-outliering event %s", event_id)
-
- # do some initial sanity-checking of the event. In particular, make
- # sure it doesn't have hundreds of prev_events or auth_events, which
- # could cause a huge state resolution or cascade of event fetches.
- try:
- self._sanity_check_event(pdu)
- except SynapseError as err:
- logger.warning("Received event failed sanity checks")
- raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
-
- # If we are currently in the process of joining this room, then we
- # queue up events for later processing.
- if room_id in self.room_queues:
- logger.info(
- "Queuing PDU from %s for now: join in progress",
- origin,
- )
- self.room_queues[room_id].append((pdu, origin))
- return
-
- # If we're not in the room just ditch the event entirely. This is
- # probably an old server that has come back and thinks we're still in
- # the room (or we've been rejoined to the room by a state reset).
- #
- # Note that if we were never in the room then we would have already
- # dropped the event, since we wouldn't know the room version.
- is_in_room = await self._event_auth_handler.check_host_in_room(
- room_id, self.server_name
- )
- if not is_in_room:
- logger.info(
- "Ignoring PDU from %s as we're not in the room",
- origin,
- )
- return None
-
- # Check that the event passes auth based on the state at the event. This is
- # done for events that are to be added to the timeline (non-outliers).
- #
- # Get missing pdus if necessary:
- # - Fetching any missing prev events to fill in gaps in the graph
- # - Fetching state if we have a hole in the graph
- if not pdu.internal_metadata.is_outlier():
- prevs = set(pdu.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if missing_prevs:
- # We only backfill backwards to the min depth.
- min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug("min_depth: %d", min_depth)
-
- if min_depth is not None and pdu.depth > min_depth:
- # If we're missing stuff, ensure we only fetch stuff one
- # at a time.
- logger.info(
- "Acquiring room lock to fetch %d missing prev_events: %s",
- len(missing_prevs),
- shortstr(missing_prevs),
- )
- with (await self._room_pdu_linearizer.queue(pdu.room_id)):
- logger.info(
- "Acquired room lock to fetch %d missing prev_events",
- len(missing_prevs),
- )
-
- try:
- await self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
- except Exception as e:
- raise Exception(
- "Error fetching missing prev_events for %s: %s"
- % (event_id, e)
- ) from e
-
- # Update the set of things we've seen after trying to
- # fetch the missing stuff
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if not missing_prevs:
- logger.info("Found all missing prev_events")
-
- if missing_prevs:
- # since this event was pushed to us, it is possible for it to
- # become the only forward-extremity in the room, and we would then
- # trust its state to be the state for the whole room. This is very
- # bad. Further, if the event was pushed to us, there is no excuse
- # for us not to have all the prev_events. (XXX: apart from
- # min_depth?)
- #
- # We therefore reject any such events.
- logger.warning(
- "Rejecting: failed to fetch %d prev events: %s",
- len(missing_prevs),
- shortstr(missing_prevs),
- )
- raise FederationError(
- "ERROR",
- 403,
- (
- "Your server isn't divulging details about prev_events "
- "referenced in this event."
- ),
- affected=pdu.event_id,
- )
-
- await self._process_received_pdu(origin, pdu, state=None)
-
- async def _get_missing_events_for_pdu(
- self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
- ) -> None:
- """
- Args:
- origin: Origin of the pdu. Will be called to get the missing events
- pdu: received pdu
- prevs: List of event ids which we are missing
- min_depth: Minimum depth of events to return.
- """
-
- room_id = pdu.room_id
- event_id = pdu.event_id
-
- seen = await self.store.have_events_in_timeline(prevs)
-
- if not prevs - seen:
- return
-
- latest_list = await self.store.get_latest_event_ids_in_room(room_id)
-
- # We add the prev events that we have seen to the latest
- # list to ensure the remote server doesn't give them to us
- latest = set(latest_list)
- latest |= seen
-
- logger.info(
- "Requesting missing events between %s and %s",
- shortstr(latest),
- event_id,
- )
-
- # XXX: we set timeout to 10s to help workaround
- # https://github.com/matrix-org/synapse/issues/1733.
- # The reason is to avoid holding the linearizer lock
- # whilst processing inbound /send transactions, causing
- # FDs to stack up and block other inbound transactions
- # which empirically can currently take up to 30 minutes.
- #
- # N.B. this explicitly disables retry attempts.
- #
- # N.B. this also increases our chances of falling back to
- # fetching fresh state for the room if the missing event
- # can't be found, which slightly reduces our security.
- # it may also increase our DAG extremity count for the room,
- # causing additional state resolution? See #1760.
- # However, fetching state doesn't hold the linearizer lock
- # apparently.
- #
- # see https://github.com/matrix-org/synapse/pull/1744
- #
- # ----
- #
- # Update richvdh 2018/09/18: There are a number of problems with timing this
- # request out aggressively on the client side:
- #
- # - it plays badly with the server-side rate-limiter, which starts tarpitting you
- # if you send too many requests at once, so you end up with the server carefully
- # working through the backlog of your requests, which you have already timed
- # out.
- #
- # - for this request in particular, we now (as of
- # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the
- # server can't produce a plausible-looking set of prev_events - so we becone
- # much more likely to reject the event.
- #
- # - contrary to what it says above, we do *not* fall back to fetching fresh state
- # for the room if get_missing_events times out. Rather, we give up processing
- # the PDU whose prevs we are missing, which then makes it much more likely that
- # we'll end up back here for the *next* PDU in the list, which exacerbates the
- # problem.
- #
- # - the aggressive 10s timeout was introduced to deal with incoming federation
- # requests taking 8 hours to process. It's not entirely clear why that was going
- # on; certainly there were other issues causing traffic storms which are now
- # resolved, and I think in any case we may be more sensible about our locking
- # now. We're *certainly* more sensible about our logging.
- #
- # All that said: Let's try increasing the timeout to 60s and see what happens.
-
- try:
- missing_events = await self.federation_client.get_missing_events(
- origin,
- room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- timeout=60000,
- )
- except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
- # We failed to get the missing events, but since we need to handle
- # the case of `get_missing_events` not returning the necessary
- # events anyway, it is safe to simply log the error and continue.
- logger.warning("Failed to get prev_events: %s", e)
- return
-
- logger.info("Got %d prev_events", len(missing_events))
- await self._process_pulled_events(origin, missing_events, backfilled=False)
-
- async def _get_state_after_missing_prev_event(
- self,
- destination: str,
- room_id: str,
- event_id: str,
- ) -> List[EventBase]:
- """Requests all of the room state at a given event from a remote homeserver.
-
- Args:
- destination: The remote homeserver to query for the state.
- room_id: The id of the room we're interested in.
- event_id: The id of the event we want the state at.
-
- Returns:
- A list of events in the state, including the event itself
- """
- (
- state_event_ids,
- auth_event_ids,
- ) = await self.federation_client.get_room_state_ids(
- destination, room_id, event_id=event_id
- )
-
- logger.debug(
- "state_ids returned %i state events, %i auth events",
- len(state_event_ids),
- len(auth_event_ids),
- )
-
- # start by just trying to fetch the events from the store
- desired_events = set(state_event_ids)
- desired_events.add(event_id)
- logger.debug("Fetching %i events from cache/store", len(desired_events))
- fetched_events = await self.store.get_events(
- desired_events, allow_rejected=True
- )
-
- missing_desired_events = desired_events - fetched_events.keys()
- logger.debug(
- "We are missing %i events (got %i)",
- len(missing_desired_events),
- len(fetched_events),
- )
-
- # We probably won't need most of the auth events, so let's just check which
- # we have for now, rather than thrashing the event cache with them all
- # unnecessarily.
-
- # TODO: we probably won't actually need all of the auth events, since we
- # already have a bunch of the state events. It would be nice if the
- # federation api gave us a way of finding out which we actually need.
-
- missing_auth_events = set(auth_event_ids) - fetched_events.keys()
- missing_auth_events.difference_update(
- await self.store.have_seen_events(room_id, missing_auth_events)
- )
- logger.debug("We are also missing %i auth events", len(missing_auth_events))
-
- missing_events = missing_desired_events | missing_auth_events
- logger.debug("Fetching %i events from remote", len(missing_events))
- await self._get_events_and_persist(
- destination=destination, room_id=room_id, events=missing_events
- )
-
- # we need to make sure we re-load from the database to get the rejected
- # state correct.
- fetched_events.update(
- await self.store.get_events(missing_desired_events, allow_rejected=True)
- )
-
- # check for events which were in the wrong room.
- #
- # this can happen if a remote server claims that the state or
- # auth_events at an event in room A are actually events in room B
-
- bad_events = [
- (event_id, event.room_id)
- for event_id, event in fetched_events.items()
- if event.room_id != room_id
- ]
-
- for bad_event_id, bad_room_id in bad_events:
- # This is a bogus situation, but since we may only discover it a long time
- # after it happened, we try our best to carry on, by just omitting the
- # bad events from the returned state set.
- logger.warning(
- "Remote server %s claims event %s in room %s is an auth/state "
- "event in room %s",
- destination,
- bad_event_id,
- bad_room_id,
- room_id,
- )
-
- del fetched_events[bad_event_id]
-
- # if we couldn't get the prev event in question, that's a problem.
- remote_event = fetched_events.get(event_id)
- if not remote_event:
- raise Exception("Unable to get missing prev_event %s" % (event_id,))
-
- # missing state at that event is a warning, not a blocker
- # XXX: this doesn't sound right? it means that we'll end up with incomplete
- # state.
- failed_to_fetch = desired_events - fetched_events.keys()
- if failed_to_fetch:
- logger.warning(
- "Failed to fetch missing state events for %s %s",
- event_id,
- failed_to_fetch,
- )
-
- remote_state = [
- fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
- ]
-
- if remote_event.is_state() and remote_event.rejected_reason is None:
- remote_state.append(remote_event)
-
- return remote_state
-
- async def _process_received_pdu(
- self,
- origin: str,
- event: EventBase,
- state: Optional[Iterable[EventBase]],
- backfilled: bool = False,
- ) -> None:
- """Called when we have a new pdu. We need to do auth checks and put it
- through the StateHandler.
-
- Args:
- origin: server sending the event
-
- event: event to be persisted
-
- state: Normally None, but if we are handling a gap in the graph
- (ie, we are missing one or more prev_events), the resolved state at the
- event
-
- backfilled: True if this is part of a historical batch of events (inhibits
- notification to clients, and validation of device keys.)
- """
- logger.debug("Processing event: %s", event)
-
- try:
- context = await self.state_handler.compute_event_context(
- event, old_state=state
- )
- await self._auth_and_persist_event(
- origin, event, context, state=state, backfilled=backfilled
- )
- except AuthError as e:
- raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
-
- if backfilled:
- return
-
- # For encrypted messages we check that we know about the sending device,
- # if we don't then we mark the device cache for that user as stale.
- if event.type == EventTypes.Encrypted:
- device_id = event.content.get("device_id")
- sender_key = event.content.get("sender_key")
-
- cached_devices = await self.store.get_cached_devices_for_user(event.sender)
-
- resync = False # Whether we should resync device lists.
-
- device = None
- if device_id is not None:
- device = cached_devices.get(device_id)
- if device is None:
- logger.info(
- "Received event from remote device not in our cache: %s %s",
- event.sender,
- device_id,
- )
- resync = True
-
- # We also check if the `sender_key` matches what we expect.
- if sender_key is not None:
- # Figure out what sender key we're expecting. If we know the
- # device and recognize the algorithm then we can work out the
- # exact key to expect. Otherwise check it matches any key we
- # have for that device.
-
- current_keys: Container[str] = []
-
- if device:
- keys = device.get("keys", {}).get("keys", {})
-
- if (
- event.content.get("algorithm")
- == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
- ):
- # For this algorithm we expect a curve25519 key.
- key_name = "curve25519:%s" % (device_id,)
- current_keys = [keys.get(key_name)]
- else:
- # We don't know understand the algorithm, so we just
- # check it matches a key for the device.
- current_keys = keys.values()
- elif device_id:
- # We don't have any keys for the device ID.
- pass
- else:
- # The event didn't include a device ID, so we just look for
- # keys across all devices.
- current_keys = [
- key
- for device in cached_devices.values()
- for key in device.get("keys", {}).get("keys", {}).values()
- ]
-
- # We now check that the sender key matches (one of) the expected
- # keys.
- if sender_key not in current_keys:
- logger.info(
- "Received event from remote device with unexpected sender key: %s %s: %s",
- event.sender,
- device_id or "",
- sender_key,
- )
- resync = True
-
- if resync:
- run_as_background_process(
- "resync_device_due_to_pdu", self._resync_device, event.sender
- )
-
- await self._handle_marker_event(origin, event)
-
- async def _handle_marker_event(self, origin: str, marker_event: EventBase):
- """Handles backfilling the insertion event when we receive a marker
- event that points to one.
-
- Args:
- origin: Origin of the event. Will be called to get the insertion event
- marker_event: The event to process
- """
-
- if marker_event.type != EventTypes.MSC2716_MARKER:
- # Not a marker event
- return
-
- if marker_event.rejected_reason is not None:
- # Rejected event
- return
-
- # Skip processing a marker event if the room version doesn't
- # support it.
- room_version = await self.store.get_room_version(marker_event.room_id)
- if not room_version.msc2716_historical:
- return
-
- logger.debug("_handle_marker_event: received %s", marker_event)
-
- insertion_event_id = marker_event.content.get(
- EventContentFields.MSC2716_MARKER_INSERTION
- )
-
- if insertion_event_id is None:
- # Nothing to retrieve then (invalid marker)
- return
-
- logger.debug(
- "_handle_marker_event: backfilling insertion event %s", insertion_event_id
- )
-
- await self._get_events_and_persist(
- origin,
- marker_event.room_id,
- [insertion_event_id],
- )
-
- insertion_event = await self.store.get_event(
- insertion_event_id, allow_none=True
- )
- if insertion_event is None:
- logger.warning(
- "_handle_marker_event: server %s didn't return insertion event %s for marker %s",
- origin,
- insertion_event_id,
- marker_event.event_id,
- )
- return
-
- logger.debug(
- "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s",
- insertion_event,
- marker_event,
- )
-
- await self.store.insert_insertion_extremity(
- insertion_event_id, marker_event.room_id
- )
-
- logger.debug(
- "_handle_marker_event: insertion extremity added for %s from marker event %s",
- insertion_event,
- marker_event,
- )
-
- async def _resync_device(self, sender: str) -> None:
- """We have detected that the device list for the given user may be out
- of sync, so we try and resync them.
- """
-
- try:
- await self.store.mark_remote_user_device_cache_as_stale(sender)
-
- # Immediately attempt a resync in the background
- if self.config.worker_app:
- await self._user_device_resync(user_id=sender)
- else:
- await self._device_list_updater.user_device_resync(sender)
- except Exception:
- logger.exception("Failed to resync device for %s", sender)
-
- @log_function
- async def backfill(
- self, dest: str, room_id: str, limit: int, extremities: List[str]
- ) -> None:
- """Trigger a backfill request to `dest` for the given `room_id`
-
- This will attempt to get more events from the remote. If the other side
- has no new events to offer, this will return an empty list.
-
- As the events are received, we check their signatures, and also do some
- sanity-checking on them. If any of the backfilled events are invalid,
- this method throws a SynapseError.
-
- We might also raise an InvalidResponseError if the response from the remote
- server is just bogus.
-
- TODO: make this more useful to distinguish failures of the remote
- server from invalid events (there is probably no point in trying to
- re-fetch invalid events from every other HS in the room.)
- """
- if dest == self.server_name:
- raise SynapseError(400, "Can't backfill from self.")
-
- events = await self.federation_client.backfill(
- dest, room_id, limit=limit, extremities=extremities
- )
-
- if not events:
- return
-
- # if there are any events in the wrong room, the remote server is buggy and
- # should not be trusted.
- for ev in events:
- if ev.room_id != room_id:
- raise InvalidResponseError(
- f"Remote server {dest} returned event {ev.event_id} which is in "
- f"room {ev.room_id}, when we were backfilling in {room_id}"
- )
-
- await self._process_pulled_events(dest, events, backfilled=True)
-
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
) -> bool:
@@ -995,7 +308,7 @@ class FederationHandler(BaseHandler):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- await self.backfill(
+ await self._federation_event_handler.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1084,316 +397,7 @@ class FederationHandler(BaseHandler):
tried_domains.update(dom for dom, _ in likely_extremeties_domains)
- return False
-
- async def _get_events_and_persist(
- self, destination: str, room_id: str, events: Iterable[str]
- ) -> None:
- """Fetch the given events from a server, and persist them as outliers.
-
- This function *does not* recursively get missing auth events of the
- newly fetched events. Callers must include in the `events` argument
- any missing events from the auth chain.
-
- Logs a warning if we can't find the given event.
- """
-
- room_version = await self.store.get_room_version(room_id)
-
- event_map: Dict[str, EventBase] = {}
-
- async def get_event(event_id: str):
- with nested_logging_context(event_id):
- try:
- event = await self.federation_client.get_pdu(
- [destination],
- event_id,
- room_version,
- outlier=True,
- )
- if event is None:
- logger.warning(
- "Server %s didn't return event %s",
- destination,
- event_id,
- )
- return
-
- event_map[event.event_id] = event
-
- except Exception as e:
- logger.warning(
- "Error fetching missing state/auth event %s: %s %s",
- event_id,
- type(e),
- e,
- )
-
- await concurrently_execute(get_event, events, 5)
-
- # Make a map of auth events for each event. We do this after fetching
- # all the events as some of the events' auth events will be in the list
- # of requested events.
-
- auth_events = [
- aid
- for event in event_map.values()
- for aid in event.auth_event_ids()
- if aid not in event_map
- ]
- persisted_events = await self.store.get_events(
- auth_events,
- allow_rejected=True,
- )
-
- event_infos = []
- for event in event_map.values():
- auth = {}
- for auth_event_id in event.auth_event_ids():
- ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
- if ae:
- auth[(ae.type, ae.state_key)] = ae
- else:
- logger.info("Missing auth event %s", auth_event_id)
-
- event_infos.append(_NewEventInfo(event, auth))
-
- if event_infos:
- await self._auth_and_persist_events(
- destination,
- room_id,
- event_infos,
- )
-
- async def _process_pulled_events(
- self, origin: str, events: Iterable[EventBase], backfilled: bool
- ) -> None:
- """Process a batch of events we have pulled from a remote server
-
- Pulls in any events required to auth the events, persists the received events,
- and notifies clients, if appropriate.
-
- Assumes the events have already had their signatures and hashes checked.
-
- Params:
- origin: The server we received these events from
- events: The received events.
- backfilled: True if this is part of a historical batch of events (inhibits
- notification to clients, and validation of device keys.)
- """
-
- # We want to sort these by depth so we process them and
- # tell clients about them in order.
- sorted_events = sorted(events, key=lambda x: x.depth)
-
- for ev in sorted_events:
- with nested_logging_context(ev.event_id):
- await self._process_pulled_event(origin, ev, backfilled=backfilled)
-
- async def _process_pulled_event(
- self, origin: str, event: EventBase, backfilled: bool
- ) -> None:
- """Process a single event that we have pulled from a remote server
-
- Pulls in any events required to auth the event, persists the received event,
- and notifies clients, if appropriate.
-
- Assumes the event has already had its signatures and hashes checked.
-
- This is somewhat equivalent to on_receive_pdu, but applies somewhat different
- logic in the case that we are missing prev_events (in particular, it just
- requests the state at that point, rather than triggering a get_missing_events) -
- so is appropriate when we have pulled the event from a remote server, rather
- than having it pushed to us.
-
- Params:
- origin: The server we received this event from
- events: The received event
- backfilled: True if this is part of a historical batch of events (inhibits
- notification to clients, and validation of device keys.)
- """
- logger.info("Processing pulled event %s", event)
-
- # these should not be outliers.
- assert not event.internal_metadata.is_outlier()
-
- event_id = event.event_id
-
- existing = await self.store.get_event(
- event_id, allow_none=True, allow_rejected=True
- )
- if existing:
- if not existing.internal_metadata.is_outlier():
- logger.info(
- "Ignoring received event %s which we have already seen",
- event_id,
- )
- return
- logger.info("De-outliering event %s", event_id)
-
- try:
- self._sanity_check_event(event)
- except SynapseError as err:
- logger.warning("Event %s failed sanity check: %s", event_id, err)
- return
-
- try:
- state = await self._resolve_state_at_missing_prevs(origin, event)
- await self._process_received_pdu(
- origin, event, state=state, backfilled=backfilled
- )
- except FederationError as e:
- if e.code == 403:
- logger.warning("Pulled event %s failed history check.", event_id)
- else:
- raise
-
- async def _resolve_state_at_missing_prevs(
- self, dest: str, event: EventBase
- ) -> Optional[Iterable[EventBase]]:
- """Calculate the state at an event with missing prev_events.
-
- This is used when we have pulled a batch of events from a remote server, and
- still don't have all the prev_events.
-
- If we already have all the prev_events for `event`, this method does nothing.
-
- Otherwise, the missing prevs become new backwards extremities, and we fall back
- to asking the remote server for the state after each missing `prev_event`,
- and resolving across them.
-
- That's ok provided we then resolve the state against other bits of the DAG
- before using it - in other words, that the received event `event` is not going
- to become the only forwards_extremity in the room (which will ensure that you
- can't just take over a room by sending an event, withholding its prev_events,
- and declaring yourself to be an admin in the subsequent state request).
-
- In other words: we should only call this method if `event` has been *pulled*
- as part of a batch of missing prev events, or similar.
-
- Params:
- dest: the remote server to ask for state at the missing prevs. Typically,
- this will be the server we got `event` from.
- event: an event to check for missing prevs.
-
- Returns:
- if we already had all the prev events, `None`. Otherwise, returns a list of
- the events in the state at `event`.
- """
- room_id = event.room_id
- event_id = event.event_id
-
- prevs = set(event.prev_event_ids())
- seen = await self.store.have_events_in_timeline(prevs)
- missing_prevs = prevs - seen
-
- if not missing_prevs:
- return None
-
- logger.info(
- "Event %s is missing prev_events %s: calculating state for a "
- "backwards extremity",
- event_id,
- shortstr(missing_prevs),
- )
- # Calculate the state after each of the previous events, and
- # resolve them to find the correct state at the current event.
- event_map = {event_id: event}
- try:
- # Get the state of the events we know about
- ours = await self.state_store.get_state_groups_ids(room_id, seen)
-
- # state_maps is a list of mappings from (type, state_key) to event_id
- state_maps: List[StateMap[str]] = list(ours.values())
-
- # we don't need this any more, let's delete it.
- del ours
-
- # Ask the remote server for the states we don't
- # know about
- for p in missing_prevs:
- logger.info("Requesting state after missing prev_event %s", p)
-
- with nested_logging_context(p):
- # note that if any of the missing prevs share missing state or
- # auth events, the requests to fetch those events are deduped
- # by the get_pdu_cache in federation_client.
- remote_state = await self._get_state_after_missing_prev_event(
- dest, room_id, p
- )
-
- remote_state_map = {
- (x.type, x.state_key): x.event_id for x in remote_state
- }
- state_maps.append(remote_state_map)
-
- for x in remote_state:
- event_map[x.event_id] = x
-
- room_version = await self.store.get_room_version_id(room_id)
- state_map = await self._state_resolution_handler.resolve_events_with_store(
- room_id,
- room_version,
- state_maps,
- event_map,
- state_res_store=StateResolutionStore(self.store),
- )
-
- # We need to give _process_received_pdu the actual state events
- # rather than event ids, so generate that now.
-
- # First though we need to fetch all the events that are in
- # state_map, so we can build up the state below.
- evs = await self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- redact_behaviour=EventRedactBehaviour.AS_IS,
- )
- event_map.update(evs)
-
- state = [event_map[e] for e in state_map.values()]
- except Exception:
- logger.warning(
- "Error attempting to resolve state at missing prev_events",
- exc_info=True,
- )
- raise FederationError(
- "ERROR",
- 403,
- "We can't get valid state history.",
- affected=event_id,
- )
- return state
-
- def _sanity_check_event(self, ev: EventBase) -> None:
- """
- Do some early sanity checks of a received event
-
- In particular, checks it doesn't have an excessive number of
- prev_events or auth_events, which could cause a huge state resolution
- or cascade of event fetches.
-
- Args:
- ev: event to be checked
-
- Raises:
- SynapseError if the event does not pass muster
- """
- if len(ev.prev_event_ids()) > 20:
- logger.warning(
- "Rejecting event %s which has %i prev_events",
- ev.event_id,
- len(ev.prev_event_ids()),
- )
- raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events")
-
- if len(ev.auth_event_ids()) > 10:
- logger.warning(
- "Rejecting event %s which has %i auth_events",
- ev.event_id,
- len(ev.auth_event_ids()),
- )
- raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
+ return False
async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing.
@@ -1460,9 +464,9 @@ class FederationHandler(BaseHandler):
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
- assert room_id not in self.room_queues
+ assert room_id not in self._federation_event_handler.room_queues
- self.room_queues[room_id] = []
+ self._federation_event_handler.room_queues[room_id] = []
await self._clean_room_for_join(room_id)
@@ -1536,8 +540,8 @@ class FederationHandler(BaseHandler):
logger.debug("Finished joining %s to %s", joinee, room_id)
return event.event_id, max_stream_id
finally:
- room_queue = self.room_queues[room_id]
- del self.room_queues[room_id]
+ room_queue = self._federation_event_handler.room_queues[room_id]
+ del self._federation_event_handler.room_queues[room_id]
# we don't need to wait for the queued events to be processed -
# it's just a best-effort thing at this point. We do want to do
@@ -1613,7 +617,7 @@ class FederationHandler(BaseHandler):
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify(
+ stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
return event.event_id, stream_id
@@ -1633,7 +637,7 @@ class FederationHandler(BaseHandler):
p,
)
with nested_logging_context(p.event_id):
- await self.on_receive_pdu(origin, p)
+ await self._federation_event_handler.on_receive_pdu(origin, p)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1726,7 +730,7 @@ class FederationHandler(BaseHandler):
raise
# Ensure the user can even join the room.
- await self._check_join_restrictions(context, event)
+ await self._federation_event_handler.check_join_restrictions(context, event)
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
@@ -1803,7 +807,9 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
- await self.persist_events_and_notify(event.room_id, [(event, context)])
+ await self._federation_event_handler.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
return event
@@ -1830,7 +836,7 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify(
+ stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -1973,116 +979,6 @@ class FederationHandler(BaseHandler):
return event
- @log_function
- async def on_send_membership_event(
- self, origin: str, event: EventBase
- ) -> Tuple[EventBase, EventContext]:
- """
- We have received a join/leave/knock event for a room via send_join/leave/knock.
-
- Verify that event and send it into the room on the remote homeserver's behalf.
-
- This is quite similar to on_receive_pdu, with the following principal
- differences:
- * only membership events are permitted (and only events with
- sender==state_key -- ie, no kicks or bans)
- * *We* send out the event on behalf of the remote server.
- * We enforce the membership restrictions of restricted rooms.
- * Rejected events result in an exception rather than being stored.
-
- There are also other differences, however it is not clear if these are by
- design or omission. In particular, we do not attempt to backfill any missing
- prev_events.
-
- Args:
- origin: The homeserver of the remote (joining/invited/knocking) user.
- event: The member event that has been signed by the remote homeserver.
-
- Returns:
- The event and context of the event after inserting it into the room graph.
-
- Raises:
- SynapseError if the event is not accepted into the room
- """
- logger.debug(
- "on_send_membership_event: Got event: %s, signatures: %s",
- event.event_id,
- event.signatures,
- )
-
- if get_domain_from_id(event.sender) != origin:
- logger.info(
- "Got send_membership request for user %r from different origin %s",
- event.sender,
- origin,
- )
- raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
- if event.sender != event.state_key:
- raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
-
- assert not event.internal_metadata.outlier
-
- # Send this event on behalf of the other server.
- #
- # The remote server isn't a full participant in the room at this point, so
- # may not have an up-to-date list of the other homeservers participating in
- # the room, so we send it on their behalf.
- event.internal_metadata.send_on_behalf_of = origin
-
- context = await self.state_handler.compute_event_context(event)
- context = await self._check_event_auth(origin, event, context)
- if context.rejected:
- raise SynapseError(
- 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
- )
-
- # for joins, we need to check the restrictions of restricted rooms
- if event.membership == Membership.JOIN:
- await self._check_join_restrictions(context, event)
-
- # for knock events, we run the third-party event rules. It's not entirely clear
- # why we don't do this for other sorts of membership events.
- if event.membership == Membership.KNOCK:
- event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
- event, context
- )
- if not event_allowed:
- logger.info("Sending of knock %s forbidden by third-party rules", event)
- raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN
- )
-
- # all looks good, we can persist the event.
- await self._run_push_actions_and_persist_event(event, context)
- return event, context
-
- async def _check_join_restrictions(
- self, context: EventContext, event: EventBase
- ) -> None:
- """Check that restrictions in restricted join rules are matched
-
- Called when we receive a join event via send_join.
-
- Raises an auth error if the restrictions are not matched.
- """
- prev_state_ids = await context.get_prev_state_ids()
-
- # Check if the user is already in the room or invited to the room.
- user_id = event.state_key
- prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
-
- # Check if the member should be allowed access via membership in a space.
- await self._event_auth_handler.check_restricted_join_rules(
- prev_state_ids,
- event.room_version,
- user_id,
- prev_member_event,
- )
-
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event."""
@@ -2183,126 +1079,6 @@ class FederationHandler(BaseHandler):
else:
return None
- async def get_min_depth_for_context(self, context: str) -> int:
- return await self.store.get_min_depth(context)
-
- async def _auth_and_persist_event(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
- backfilled: bool = False,
- ) -> None:
- """
- Process an event by performing auth checks and then persisting to the database.
-
- Args:
- origin: The host the event originates from.
- event: The event itself.
- context:
- The event context.
-
- state:
- The state events used to check the event for soft-fail. If this is
- not provided the current state events will be used.
-
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
-
- Only populated where we may not already have persisted these events -
- for example, when populating outliers.
-
- backfilled: True if the event was backfilled.
- """
- context = await self._check_event_auth(
- origin,
- event,
- context,
- state=state,
- claimed_auth_event_map=claimed_auth_event_map,
- backfilled=backfilled,
- )
-
- await self._run_push_actions_and_persist_event(event, context, backfilled)
-
- async def _run_push_actions_and_persist_event(
- self, event: EventBase, context: EventContext, backfilled: bool = False
- ):
- """Run the push actions for a received event, and persist it.
-
- Args:
- event: The event itself.
- context: The event context.
- backfilled: True if the event was backfilled.
- """
- try:
- if (
- not event.internal_metadata.is_outlier()
- and not backfilled
- and not context.rejected
- and (await self.store.get_min_depth(event.room_id)) <= event.depth
- ):
- await self.action_generator.handle_push_actions_for_event(
- event, context
- )
-
- await self.persist_events_and_notify(
- event.room_id, [(event, context)], backfilled=backfilled
- )
- except Exception:
- run_in_background(
- self.store.remove_push_actions_from_staging, event.event_id
- )
- raise
-
- async def _auth_and_persist_events(
- self,
- origin: str,
- room_id: str,
- event_infos: Collection[_NewEventInfo],
- ) -> None:
- """Creates the appropriate contexts and persists events. The events
- should not depend on one another, e.g. this should be used to persist
- a bunch of outliers, but not a chunk of individual events that depend
- on each other for state calculations.
-
- Notifies about the events where appropriate.
- """
-
- if not event_infos:
- return
-
- async def prep(ev_info: _NewEventInfo):
- event = ev_info.event
- with nested_logging_context(suffix=event.event_id):
- res = await self.state_handler.compute_event_context(event)
- res = await self._check_event_auth(
- origin,
- event,
- res,
- claimed_auth_event_map=ev_info.claimed_auth_event_map,
- )
- return res
-
- contexts = await make_deferred_yieldable(
- defer.gatherResults(
- [run_in_background(prep, ev_info) for ev_info in event_infos],
- consumeErrors=True,
- )
- )
-
- await self.persist_events_and_notify(
- room_id,
- [
- (ev_info.event, context)
- for ev_info, context in zip(event_infos, contexts)
- ],
- )
-
async def _persist_auth_tree(
self,
origin: str,
@@ -2400,7 +1176,7 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
if auth_events or state:
- await self.persist_events_and_notify(
+ await self._federation_event_handler.persist_events_and_notify(
room_id,
[
(e, events_to_context[e.event_id])
@@ -2412,108 +1188,10 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
- return await self.persist_events_and_notify(
+ return await self._federation_event_handler.persist_events_and_notify(
room_id, [(event, new_event_context)]
)
- async def _check_for_soft_fail(
- self,
- event: EventBase,
- state: Optional[Iterable[EventBase]],
- backfilled: bool,
- origin: str,
- ) -> None:
- """Checks if we should soft fail the event; if so, marks the event as
- such.
-
- Args:
- event
- state: The state at the event if we don't have all the event's prev events
- backfilled: Whether the event is from backfill
- origin: The host the event originates from.
- """
- # For new (non-backfilled and non-outlier) events we check if the event
- # passes auth based on the current state. If it doesn't then we
- # "soft-fail" the event.
- if backfilled or event.internal_metadata.is_outlier():
- return
-
- extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
- extrem_ids = set(extrem_ids_list)
- prev_event_ids = set(event.prev_event_ids())
-
- if extrem_ids == prev_event_ids:
- # If they're the same then the current state is the same as the
- # state at the event, so no point rechecking auth for soft fail.
- return
-
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- # Calculate the "current state".
- if state is not None:
- # If we're explicitly given the state then we won't have all the
- # prev events, and so we have a gap in the graph. In this case
- # we want to be a little careful as we might have been down for
- # a while and have an incorrect view of the current state,
- # however we still want to do checks as gaps are easy to
- # maliciously manufacture.
- #
- # So we use a "current state" that is actually a state
- # resolution across the current forward extremities and the
- # given state at the event. This should correctly handle cases
- # like bans, especially with state res v2.
-
- state_sets_d = await self.state_store.get_state_groups(
- event.room_id, extrem_ids
- )
- state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
- state_sets.append(state)
- current_states = await self.state_handler.resolve_events(
- room_version, state_sets, event
- )
- current_state_ids: StateMap[str] = {
- k: e.event_id for k, e in current_states.items()
- }
- else:
- current_state_ids = await self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids
- )
-
- logger.debug(
- "Doing soft-fail check for %s: state %s",
- event.event_id,
- current_state_ids,
- )
-
- # Now check if event pass auth against said current state
- auth_types = auth_types_for_event(room_version_obj, event)
- current_state_ids_list = [
- e for k, e in current_state_ids.items() if k in auth_types
- ]
-
- auth_events_map = await self.store.get_events(current_state_ids_list)
- current_auth_events = {
- (e.type, e.state_key): e for e in auth_events_map.values()
- }
-
- try:
- event_auth.check(room_version_obj, event, auth_events=current_auth_events)
- except AuthError as e:
- logger.warning(
- "Soft-failing %r (from %s) because %s",
- event,
- e,
- origin,
- extra={
- "room_id": event.room_id,
- "mxid": event.sender,
- "hs": origin,
- },
- )
- soft_failed_event_counter.inc()
- event.internal_metadata.soft_failed = True
-
async def on_get_missing_events(
self,
origin: str,
@@ -2542,334 +1220,6 @@ class FederationHandler(BaseHandler):
return missing_events
- async def _check_event_auth(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- state: Optional[Iterable[EventBase]] = None,
- claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
- backfilled: bool = False,
- ) -> EventContext:
- """
- Checks whether an event should be rejected (for failing auth checks).
-
- Args:
- origin: The host the event originates from.
- event: The event itself.
- context:
- The event context.
-
- state:
- The state events used to check the event for soft-fail. If this is
- not provided the current state events will be used.
-
- claimed_auth_event_map:
- A map of (type, state_key) => event for the event's claimed auth_events.
- Possibly incomplete, and possibly including events that are not yet
- persisted, or authed, or in the right room.
-
- Only populated where we may not already have persisted these events -
- for example, when populating outliers, or the state for a backwards
- extremity.
-
- backfilled: True if the event was backfilled.
-
- Returns:
- The updated context object.
- """
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
- if claimed_auth_event_map:
- # if we have a copy of the auth events from the event, use that as the
- # basis for auth.
- auth_events = claimed_auth_event_map
- else:
- # otherwise, we calculate what the auth events *should* be, and use that
- prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = self._event_auth_handler.compute_auth_events(
- event, prev_state_ids, for_verification=True
- )
- auth_events_x = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
-
- try:
- (
- context,
- auth_events_for_auth,
- ) = await self._update_auth_events_and_context_for_auth(
- origin, event, context, auth_events
- )
- except Exception:
- # We don't really mind if the above fails, so lets not fail
- # processing if it does. However, it really shouldn't fail so
- # let's still log as an exception since we'll still want to fix
- # any bugs.
- logger.exception(
- "Failed to double check auth events for %s with remote. "
- "Ignoring failure and continuing processing of event.",
- event.event_id,
- )
- auth_events_for_auth = auth_events
-
- try:
- event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
- except AuthError as e:
- logger.warning("Failed auth resolution for %r because %s", event, e)
- context.rejected = RejectedReason.AUTH_ERROR
-
- if not context.rejected:
- await self._check_for_soft_fail(event, state, backfilled, origin=origin)
-
- if event.type == EventTypes.GuestAccess and not context.rejected:
- await self.maybe_kick_guest_users(event)
-
- # If we are going to send this event over federation we precaclculate
- # the joined hosts.
- if event.internal_metadata.get_send_on_behalf_of():
- await self.event_creation_handler.cache_joined_hosts_for_event(
- event, context
- )
-
- return context
-
- async def _update_auth_events_and_context_for_auth(
- self,
- origin: str,
- event: EventBase,
- context: EventContext,
- input_auth_events: StateMap[EventBase],
- ) -> Tuple[EventContext, StateMap[EventBase]]:
- """Helper for _check_event_auth. See there for docs.
-
- Checks whether a given event has the expected auth events. If it
- doesn't then we talk to the remote server to compare state to see if
- we can come to a consensus (e.g. if one server missed some valid
- state).
-
- This attempts to resolve any potential divergence of state between
- servers, but is not essential and so failures should not block further
- processing of the event.
-
- Args:
- origin:
- event:
- context:
-
- input_auth_events:
- Map from (event_type, state_key) to event
-
- Normally, our calculated auth_events based on the state of the room
- at the event's position in the DAG, though occasionally (eg if the
- event is an outlier), may be the auth events claimed by the remote
- server.
-
- Returns:
- updated context, updated auth event map
- """
- # take a copy of input_auth_events before we modify it.
- auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
-
- event_auth_events = set(event.auth_event_ids())
-
- # missing_auth is the set of the event's auth_events which we don't yet have
- # in auth_events.
- missing_auth = event_auth_events.difference(
- e.event_id for e in auth_events.values()
- )
-
- # if we have missing events, we need to fetch those events from somewhere.
- #
- # we start by checking if they are in the store, and then try calling /event_auth/.
- if missing_auth:
- have_events = await self.store.have_seen_events(event.room_id, missing_auth)
- logger.debug("Events %s are in the store", have_events)
- missing_auth.difference_update(have_events)
-
- if missing_auth:
- # If we don't have all the auth events, we need to get them.
- logger.info("auth_events contains unknown events: %s", missing_auth)
- try:
- try:
- remote_auth_chain = await self.federation_client.get_event_auth(
- origin, event.room_id, event.event_id
- )
- except RequestSendFailed as e1:
- # The other side isn't around or doesn't implement the
- # endpoint, so lets just bail out.
- logger.info("Failed to get event auth from remote: %s", e1)
- return context, auth_events
-
- seen_remotes = await self.store.have_seen_events(
- event.room_id, [e.event_id for e in remote_auth_chain]
- )
-
- for e in remote_auth_chain:
- if e.event_id in seen_remotes:
- continue
-
- if e.event_id == event.event_id:
- continue
-
- try:
- auth_ids = e.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in remote_auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- e.internal_metadata.outlier = True
-
- logger.debug(
- "_check_event_auth %s missing_auth: %s",
- event.event_id,
- e.event_id,
- )
- missing_auth_event_context = (
- await self.state_handler.compute_event_context(e)
- )
- await self._auth_and_persist_event(
- origin,
- e,
- missing_auth_event_context,
- claimed_auth_event_map=auth,
- )
-
- if e.event_id in event_auth_events:
- auth_events[(e.type, e.state_key)] = e
- except AuthError:
- pass
-
- except Exception:
- logger.exception("Failed to get auth chain")
-
- if event.internal_metadata.is_outlier():
- # XXX: given that, for an outlier, we'll be working with the
- # event's *claimed* auth events rather than those we calculated:
- # (a) is there any point in this test, since different_auth below will
- # obviously be empty
- # (b) alternatively, why don't we do it earlier?
- logger.info("Skipping auth_event fetch for outlier")
- return context, auth_events
-
- different_auth = event_auth_events.difference(
- e.event_id for e in auth_events.values()
- )
-
- if not different_auth:
- return context, auth_events
-
- logger.info(
- "auth_events refers to events which are not in our calculated auth "
- "chain: %s",
- different_auth,
- )
-
- # XXX: currently this checks for redactions but I'm not convinced that is
- # necessary?
- different_events = await self.store.get_events_as_list(different_auth)
-
- for d in different_events:
- if d.room_id != event.room_id:
- logger.warning(
- "Event %s refers to auth_event %s which is in a different room",
- event.event_id,
- d.event_id,
- )
-
- # don't attempt to resolve the claimed auth events against our own
- # in this case: just use our own auth events.
- #
- # XXX: should we reject the event in this case? It feels like we should,
- # but then shouldn't we also do so if we've failed to fetch any of the
- # auth events?
- return context, auth_events
-
- # now we state-resolve between our own idea of the auth events, and the remote's
- # idea of them.
-
- local_state = auth_events.values()
- remote_auth_events = dict(auth_events)
- remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
- remote_state = remote_auth_events.values()
-
- room_version = await self.store.get_room_version_id(event.room_id)
- new_state = await self.state_handler.resolve_events(
- room_version, (local_state, remote_state), event
- )
-
- logger.info(
- "After state res: updating auth_events with new state %s",
- {
- (d.type, d.state_key): d.event_id
- for d in new_state.values()
- if auth_events.get((d.type, d.state_key)) != d
- },
- )
-
- auth_events.update(new_state)
-
- context = await self._update_context_for_auth_events(
- event, context, auth_events
- )
-
- return context, auth_events
-
- async def _update_context_for_auth_events(
- self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
- ) -> EventContext:
- """Update the state_ids in an event context after auth event resolution,
- storing the changes as a new state group.
-
- Args:
- event: The event we're handling the context for
-
- context: initial event context
-
- auth_events: Events to update in the event context.
-
- Returns:
- new event context
- """
- # exclude the state key of the new event from the current_state in the context.
- if event.is_state():
- event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
- else:
- event_key = None
- state_updates = {
- k: a.event_id for k, a in auth_events.items() if k != event_key
- }
-
- current_state_ids = await context.get_current_state_ids()
- current_state_ids = dict(current_state_ids) # type: ignore
-
- current_state_ids.update(state_updates)
-
- prev_state_ids = await context.get_prev_state_ids()
- prev_state_ids = dict(prev_state_ids)
-
- prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
- # create a new state group as a delta from the existing one.
- prev_group = context.state_group
- state_group = await self.state_store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=prev_group,
- delta_ids=state_updates,
- current_state_ids=current_state_ids,
- )
-
- return EventContext.with_state(
- state_group=state_group,
- state_group_before_event=context.state_group_before_event,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
- prev_group=prev_group,
- delta_ids=state_updates,
- )
-
async def construct_auth_difference(
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
) -> Dict:
@@ -3256,99 +1606,6 @@ class FederationHandler(BaseHandler):
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")
- async def persist_events_and_notify(
- self,
- room_id: str,
- event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
- backfilled: bool = False,
- ) -> int:
- """Persists events and tells the notifier/pushers about them, if
- necessary.
-
- Args:
- room_id: The room ID of events being persisted.
- event_and_contexts: Sequence of events with their associated
- context that should be persisted. All events must belong to
- the same room.
- backfilled: Whether these events are a result of
- backfilling or not
-
- Returns:
- The stream ID after which all events have been persisted.
- """
- if not event_and_contexts:
- return self.store.get_current_events_token()
-
- instance = self.config.worker.events_shard_config.get_instance(room_id)
- if instance != self._instance_name:
- # Limit the number of events sent over replication. We choose 200
- # here as that is what we default to in `max_request_body_size(..)`
- for batch in batch_iter(event_and_contexts, 200):
- result = await self._send_events(
- instance_name=instance,
- store=self.store,
- room_id=room_id,
- event_and_contexts=batch,
- backfilled=backfilled,
- )
- return result["max_stream_id"]
- else:
- assert self.storage.persistence
-
- # Note that this returns the events that were persisted, which may not be
- # the same as were passed in if some were deduplicated due to transaction IDs.
- events, max_stream_token = await self.storage.persistence.persist_events(
- event_and_contexts, backfilled=backfilled
- )
-
- if self._ephemeral_messages_enabled:
- for event in events:
- # If there's an expiry timestamp on the event, schedule its expiry.
- self._message_handler.maybe_schedule_expiry(event)
-
- if not backfilled: # Never notify for backfilled events
- for event in events:
- await self._notify_persisted_event(event, max_stream_token)
-
- return max_stream_token.stream
-
- async def _notify_persisted_event(
- self, event: EventBase, max_stream_token: RoomStreamToken
- ) -> None:
- """Checks to see if notifier/pushers should be notified about the
- event or not.
-
- Args:
- event:
- max_stream_id: The max_stream_id returned by persist_events
- """
-
- extra_users = []
- if event.type == EventTypes.Member:
- target_user_id = event.state_key
-
- # We notify for memberships if its an invite for one of our
- # users
- if event.internal_metadata.is_outlier():
- if event.membership != Membership.INVITE:
- if not self.is_mine_id(target_user_id):
- return
-
- target_user = UserID.from_string(target_user_id)
- extra_users.append(target_user)
- elif event.internal_metadata.is_outlier():
- return
-
- # the event has been persisted so it should have a stream ordering.
- assert event.internal_metadata.stream_ordering
-
- event_pos = PersistedEventPosition(
- self._instance_name, event.internal_metadata.stream_ordering
- )
- self.notifier.on_new_room_event(
- event, event_pos, max_stream_token, extra_users=extra_users
- )
-
async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
new file mode 100644
index 0000000000..9f055f00cf
--- /dev/null
+++ b/synapse/handlers/federation_event.py
@@ -0,0 +1,1825 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from http import HTTPStatus
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Container,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
+
+import attr
+from prometheus_client import Counter
+
+from twisted.internet import defer
+
+from synapse import event_auth
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RejectedReason,
+ RoomEncryptionAlgorithms,
+)
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ FederationError,
+ HttpResponseException,
+ RequestSendFailed,
+ SynapseError,
+)
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.federation.federation_client import InvalidResponseError
+from synapse.handlers._base import BaseHandler
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
+from synapse.logging.utils import log_function
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
+from synapse.replication.http.federation import (
+ ReplicationFederationSendEventsRestServlet,
+)
+from synapse.state import StateResolutionStore
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ MutableStateMap,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
+from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.iterutils import batch_iter
+from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+soft_failed_event_counter = Counter(
+ "synapse_federation_soft_failed_events_total",
+ "Events received over federation that we marked as soft_failed",
+)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _NewEventInfo:
+ """Holds information about a received event, ready for passing to _auth_and_persist_events
+
+ Attributes:
+ event: the received event
+
+ claimed_auth_event_map: a map of (type, state_key) => event for the event's
+ claimed auth_events.
+
+ This can include events which have not yet been persisted, in the case that
+ we are backfilling a batch of events.
+
+ Note: May be incomplete: if we were unable to find all of the claimed auth
+ events. Also, treat the contents with caution: the events might also have
+ been rejected, might not yet have been authorized themselves, or they might
+ be in the wrong room.
+
+ """
+
+ event: EventBase
+ claimed_auth_event_map: StateMap[EventBase]
+
+
+class FederationEventHandler(BaseHandler):
+ """Handles events that originated from federation.
+
+ Responsible for handing incoming events and passing them on to the rest
+ of the homeserver (including auth and state conflict resolutions)
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+
+ self.state_handler = hs.get_state_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self._event_auth_handler = hs.get_event_auth_handler()
+ self._message_handler = hs.get_message_handler()
+ self.action_generator = hs.get_action_generator()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
+ self.federation_client = hs.get_federation_client()
+ self.third_party_event_rules = hs.get_third_party_event_rules()
+
+ self.is_mine_id = hs.is_mine_id
+ self._instance_name = hs.get_instance_name()
+
+ self.config = hs.config
+ self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
+
+ self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
+ if hs.config.worker_app:
+ self._user_device_resync = (
+ ReplicationUserDevicesResyncRestServlet.make_client(hs)
+ )
+ else:
+ self._device_list_updater = hs.get_device_handler().device_list_updater
+
+ # When joining a room we need to queue any events for that room up.
+ # For each room, a list of (pdu, origin) tuples.
+ # TODO: replace this with something more elegant, probably based around the
+ # federation event staging area.
+ self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
+
+ self._room_pdu_linearizer = Linearizer("fed_room_pdu")
+
+ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+ """Process a PDU received via a federation /send/ transaction
+
+ Args:
+ origin: server which initiated the /send/ transaction. Will
+ be used to fetch missing events or state.
+ pdu: received PDU
+ """
+
+ room_id = pdu.room_id
+ event_id = pdu.event_id
+
+ # We reprocess pdus when we have seen them only as outliers
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
+ )
+
+ # FIXME: Currently we fetch an event again when we already have it
+ # if it has been marked as an outlier.
+ if existing:
+ if not existing.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received event %s which we have already seen", event_id
+ )
+ return
+ if pdu.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received outlier %s which we already have as an outlier",
+ event_id,
+ )
+ return
+ logger.info("De-outliering event %s", event_id)
+
+ # do some initial sanity-checking of the event. In particular, make
+ # sure it doesn't have hundreds of prev_events or auth_events, which
+ # could cause a huge state resolution or cascade of event fetches.
+ try:
+ self._sanity_check_event(pdu)
+ except SynapseError as err:
+ logger.warning("Received event failed sanity checks")
+ raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
+
+ # If we are currently in the process of joining this room, then we
+ # queue up events for later processing.
+ if room_id in self.room_queues:
+ logger.info(
+ "Queuing PDU from %s for now: join in progress",
+ origin,
+ )
+ self.room_queues[room_id].append((pdu, origin))
+ return
+
+ # If we're not in the room just ditch the event entirely. This is
+ # probably an old server that has come back and thinks we're still in
+ # the room (or we've been rejoined to the room by a state reset).
+ #
+ # Note that if we were never in the room then we would have already
+ # dropped the event, since we wouldn't know the room version.
+ is_in_room = await self._event_auth_handler.check_host_in_room(
+ room_id, self.server_name
+ )
+ if not is_in_room:
+ logger.info(
+ "Ignoring PDU from %s as we're not in the room",
+ origin,
+ )
+ return None
+
+ # Check that the event passes auth based on the state at the event. This is
+ # done for events that are to be added to the timeline (non-outliers).
+ #
+ # Get missing pdus if necessary:
+ # - Fetching any missing prev events to fill in gaps in the graph
+ # - Fetching state if we have a hole in the graph
+ if not pdu.internal_metadata.is_outlier():
+ prevs = set(pdu.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if missing_prevs:
+ # We only backfill backwards to the min depth.
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
+ logger.debug("min_depth: %d", min_depth)
+
+ if min_depth is not None and pdu.depth > min_depth:
+ # If we're missing stuff, ensure we only fetch stuff one
+ # at a time.
+ logger.info(
+ "Acquiring room lock to fetch %d missing prev_events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
+ logger.info(
+ "Acquired room lock to fetch %d missing prev_events",
+ len(missing_prevs),
+ )
+
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ ) from e
+
+ # Update the set of things we've seen after trying to
+ # fetch the missing stuff
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ logger.info("Found all missing prev_events")
+
+ if missing_prevs:
+ # since this event was pushed to us, it is possible for it to
+ # become the only forward-extremity in the room, and we would then
+ # trust its state to be the state for the whole room. This is very
+ # bad. Further, if the event was pushed to us, there is no excuse
+ # for us not to have all the prev_events. (XXX: apart from
+ # min_depth?)
+ #
+ # We therefore reject any such events.
+ logger.warning(
+ "Rejecting: failed to fetch %d prev events: %s",
+ len(missing_prevs),
+ shortstr(missing_prevs),
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ (
+ "Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ affected=pdu.event_id,
+ )
+
+ await self._process_received_pdu(origin, pdu, state=None)
+
+ @log_function
+ async def on_send_membership_event(
+ self, origin: str, event: EventBase
+ ) -> Tuple[EventBase, EventContext]:
+ """
+ We have received a join/leave/knock event for a room via send_join/leave/knock.
+
+ Verify that event and send it into the room on the remote homeserver's behalf.
+
+ This is quite similar to on_receive_pdu, with the following principal
+ differences:
+ * only membership events are permitted (and only events with
+ sender==state_key -- ie, no kicks or bans)
+ * *We* send out the event on behalf of the remote server.
+ * We enforce the membership restrictions of restricted rooms.
+ * Rejected events result in an exception rather than being stored.
+
+ There are also other differences, however it is not clear if these are by
+ design or omission. In particular, we do not attempt to backfill any missing
+ prev_events.
+
+ Args:
+ origin: The homeserver of the remote (joining/invited/knocking) user.
+ event: The member event that has been signed by the remote homeserver.
+
+ Returns:
+ The event and context of the event after inserting it into the room graph.
+
+ Raises:
+ SynapseError if the event is not accepted into the room
+ """
+ logger.debug(
+ "on_send_membership_event: Got event: %s, signatures: %s",
+ event.event_id,
+ event.signatures,
+ )
+
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got send_membership request for user %r from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ if event.sender != event.state_key:
+ raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
+
+ assert not event.internal_metadata.outlier
+
+ # Send this event on behalf of the other server.
+ #
+ # The remote server isn't a full participant in the room at this point, so
+ # may not have an up-to-date list of the other homeservers participating in
+ # the room, so we send it on their behalf.
+ event.internal_metadata.send_on_behalf_of = origin
+
+ context = await self.state_handler.compute_event_context(event)
+ context = await self._check_event_auth(origin, event, context)
+ if context.rejected:
+ raise SynapseError(
+ 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
+ )
+
+ # for joins, we need to check the restrictions of restricted rooms
+ if event.membership == Membership.JOIN:
+ await self.check_join_restrictions(context, event)
+
+ # for knock events, we run the third-party event rules. It's not entirely clear
+ # why we don't do this for other sorts of membership events.
+ if event.membership == Membership.KNOCK:
+ event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Sending of knock %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
+ # all looks good, we can persist the event.
+ await self._run_push_actions_and_persist_event(event, context)
+ return event, context
+
+ async def check_join_restrictions(
+ self, context: EventContext, event: EventBase
+ ) -> None:
+ """Check that restrictions in restricted join rules are matched
+
+ Called when we receive a join event via send_join.
+
+ Raises an auth error if the restrictions are not matched.
+ """
+ prev_state_ids = await context.get_prev_state_ids()
+
+ # Check if the user is already in the room or invited to the room.
+ user_id = event.state_key
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+
+ # Check if the member should be allowed access via membership in a space.
+ await self._event_auth_handler.check_restricted_join_rules(
+ prev_state_ids,
+ event.room_version,
+ user_id,
+ prev_member_event,
+ )
+
+ @log_function
+ async def backfill(
+ self, dest: str, room_id: str, limit: int, extremities: List[str]
+ ) -> None:
+ """Trigger a backfill request to `dest` for the given `room_id`
+
+ This will attempt to get more events from the remote. If the other side
+ has no new events to offer, this will return an empty list.
+
+ As the events are received, we check their signatures, and also do some
+ sanity-checking on them. If any of the backfilled events are invalid,
+ this method throws a SynapseError.
+
+ We might also raise an InvalidResponseError if the response from the remote
+ server is just bogus.
+
+ TODO: make this more useful to distinguish failures of the remote
+ server from invalid events (there is probably no point in trying to
+ re-fetch invalid events from every other HS in the room.)
+ """
+ if dest == self.server_name:
+ raise SynapseError(400, "Can't backfill from self.")
+
+ events = await self.federation_client.backfill(
+ dest, room_id, limit=limit, extremities=extremities
+ )
+
+ if not events:
+ return
+
+ # if there are any events in the wrong room, the remote server is buggy and
+ # should not be trusted.
+ for ev in events:
+ if ev.room_id != room_id:
+ raise InvalidResponseError(
+ f"Remote server {dest} returned event {ev.event_id} which is in "
+ f"room {ev.room_id}, when we were backfilling in {room_id}"
+ )
+
+ await self._process_pulled_events(dest, events, backfilled=True)
+
+ async def _get_missing_events_for_pdu(
+ self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+ ) -> None:
+ """
+ Args:
+ origin: Origin of the pdu. Will be called to get the missing events
+ pdu: received pdu
+ prevs: List of event ids which we are missing
+ min_depth: Minimum depth of events to return.
+ """
+
+ room_id = pdu.room_id
+ event_id = pdu.event_id
+
+ seen = await self.store.have_events_in_timeline(prevs)
+
+ if not prevs - seen:
+ return
+
+ latest_list = await self.store.get_latest_event_ids_in_room(room_id)
+
+ # We add the prev events that we have seen to the latest
+ # list to ensure the remote server doesn't give them to us
+ latest = set(latest_list)
+ latest |= seen
+
+ logger.info(
+ "Requesting missing events between %s and %s",
+ shortstr(latest),
+ event_id,
+ )
+
+ # XXX: we set timeout to 10s to help workaround
+ # https://github.com/matrix-org/synapse/issues/1733.
+ # The reason is to avoid holding the linearizer lock
+ # whilst processing inbound /send transactions, causing
+ # FDs to stack up and block other inbound transactions
+ # which empirically can currently take up to 30 minutes.
+ #
+ # N.B. this explicitly disables retry attempts.
+ #
+ # N.B. this also increases our chances of falling back to
+ # fetching fresh state for the room if the missing event
+ # can't be found, which slightly reduces our security.
+ # it may also increase our DAG extremity count for the room,
+ # causing additional state resolution? See #1760.
+ # However, fetching state doesn't hold the linearizer lock
+ # apparently.
+ #
+ # see https://github.com/matrix-org/synapse/pull/1744
+ #
+ # ----
+ #
+ # Update richvdh 2018/09/18: There are a number of problems with timing this
+ # request out aggressively on the client side:
+ #
+ # - it plays badly with the server-side rate-limiter, which starts tarpitting you
+ # if you send too many requests at once, so you end up with the server carefully
+ # working through the backlog of your requests, which you have already timed
+ # out.
+ #
+ # - for this request in particular, we now (as of
+ # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the
+ # server can't produce a plausible-looking set of prev_events - so we becone
+ # much more likely to reject the event.
+ #
+ # - contrary to what it says above, we do *not* fall back to fetching fresh state
+ # for the room if get_missing_events times out. Rather, we give up processing
+ # the PDU whose prevs we are missing, which then makes it much more likely that
+ # we'll end up back here for the *next* PDU in the list, which exacerbates the
+ # problem.
+ #
+ # - the aggressive 10s timeout was introduced to deal with incoming federation
+ # requests taking 8 hours to process. It's not entirely clear why that was going
+ # on; certainly there were other issues causing traffic storms which are now
+ # resolved, and I think in any case we may be more sensible about our locking
+ # now. We're *certainly* more sensible about our logging.
+ #
+ # All that said: Let's try increasing the timeout to 60s and see what happens.
+
+ try:
+ missing_events = await self.federation_client.get_missing_events(
+ origin,
+ room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ timeout=60000,
+ )
+ except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
+ # We failed to get the missing events, but since we need to handle
+ # the case of `get_missing_events` not returning the necessary
+ # events anyway, it is safe to simply log the error and continue.
+ logger.warning("Failed to get prev_events: %s", e)
+ return
+
+ logger.info("Got %d prev_events", len(missing_events))
+ await self._process_pulled_events(origin, missing_events, backfilled=False)
+
+ async def _process_pulled_events(
+ self, origin: str, events: Iterable[EventBase], backfilled: bool
+ ) -> None:
+ """Process a batch of events we have pulled from a remote server
+
+ Pulls in any events required to auth the events, persists the received events,
+ and notifies clients, if appropriate.
+
+ Assumes the events have already had their signatures and hashes checked.
+
+ Params:
+ origin: The server we received these events from
+ events: The received events.
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+
+ # We want to sort these by depth so we process them and
+ # tell clients about them in order.
+ sorted_events = sorted(events, key=lambda x: x.depth)
+
+ for ev in sorted_events:
+ with nested_logging_context(ev.event_id):
+ await self._process_pulled_event(origin, ev, backfilled=backfilled)
+
+ async def _process_pulled_event(
+ self, origin: str, event: EventBase, backfilled: bool
+ ) -> None:
+ """Process a single event that we have pulled from a remote server
+
+ Pulls in any events required to auth the event, persists the received event,
+ and notifies clients, if appropriate.
+
+ Assumes the event has already had its signatures and hashes checked.
+
+ This is somewhat equivalent to on_receive_pdu, but applies somewhat different
+ logic in the case that we are missing prev_events (in particular, it just
+ requests the state at that point, rather than triggering a get_missing_events) -
+ so is appropriate when we have pulled the event from a remote server, rather
+ than having it pushed to us.
+
+ Params:
+ origin: The server we received this event from
+ events: The received event
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+ logger.info("Processing pulled event %s", event)
+
+ # these should not be outliers.
+ assert not event.internal_metadata.is_outlier()
+
+ event_id = event.event_id
+
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
+ )
+ if existing:
+ if not existing.internal_metadata.is_outlier():
+ logger.info(
+ "Ignoring received event %s which we have already seen",
+ event_id,
+ )
+ return
+ logger.info("De-outliering event %s", event_id)
+
+ try:
+ self._sanity_check_event(event)
+ except SynapseError as err:
+ logger.warning("Event %s failed sanity check: %s", event_id, err)
+ return
+
+ try:
+ state = await self._resolve_state_at_missing_prevs(origin, event)
+ await self._process_received_pdu(
+ origin, event, state=state, backfilled=backfilled
+ )
+ except FederationError as e:
+ if e.code == 403:
+ logger.warning("Pulled event %s failed history check.", event_id)
+ else:
+ raise
+
+ async def _resolve_state_at_missing_prevs(
+ self, dest: str, event: EventBase
+ ) -> Optional[Iterable[EventBase]]:
+ """Calculate the state at an event with missing prev_events.
+
+ This is used when we have pulled a batch of events from a remote server, and
+ still don't have all the prev_events.
+
+ If we already have all the prev_events for `event`, this method does nothing.
+
+ Otherwise, the missing prevs become new backwards extremities, and we fall back
+ to asking the remote server for the state after each missing `prev_event`,
+ and resolving across them.
+
+ That's ok provided we then resolve the state against other bits of the DAG
+ before using it - in other words, that the received event `event` is not going
+ to become the only forwards_extremity in the room (which will ensure that you
+ can't just take over a room by sending an event, withholding its prev_events,
+ and declaring yourself to be an admin in the subsequent state request).
+
+ In other words: we should only call this method if `event` has been *pulled*
+ as part of a batch of missing prev events, or similar.
+
+ Params:
+ dest: the remote server to ask for state at the missing prevs. Typically,
+ this will be the server we got `event` from.
+ event: an event to check for missing prevs.
+
+ Returns:
+ if we already had all the prev events, `None`. Otherwise, returns a list of
+ the events in the state at `event`.
+ """
+ room_id = event.room_id
+ event_id = event.event_id
+
+ prevs = set(event.prev_event_ids())
+ seen = await self.store.have_events_in_timeline(prevs)
+ missing_prevs = prevs - seen
+
+ if not missing_prevs:
+ return None
+
+ logger.info(
+ "Event %s is missing prev_events %s: calculating state for a "
+ "backwards extremity",
+ event_id,
+ shortstr(missing_prevs),
+ )
+ # Calculate the state after each of the previous events, and
+ # resolve them to find the correct state at the current event.
+ event_map = {event_id: event}
+ try:
+ # Get the state of the events we know about
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
+
+ # state_maps is a list of mappings from (type, state_key) to event_id
+ state_maps: List[StateMap[str]] = list(ours.values())
+
+ # we don't need this any more, let's delete it.
+ del ours
+
+ # Ask the remote server for the states we don't
+ # know about
+ for p in missing_prevs:
+ logger.info("Requesting state after missing prev_event %s", p)
+
+ with nested_logging_context(p):
+ # note that if any of the missing prevs share missing state or
+ # auth events, the requests to fetch those events are deduped
+ # by the get_pdu_cache in federation_client.
+ remote_state = await self._get_state_after_missing_prev_event(
+ dest, room_id, p
+ )
+
+ remote_state_map = {
+ (x.type, x.state_key): x.event_id for x in remote_state
+ }
+ state_maps.append(remote_state_map)
+
+ for x in remote_state:
+ event_map[x.event_id] = x
+
+ room_version = await self.store.get_room_version_id(room_id)
+ state_map = await self._state_resolution_handler.resolve_events_with_store(
+ room_id,
+ room_version,
+ state_maps,
+ event_map,
+ state_res_store=StateResolutionStore(self.store),
+ )
+
+ # We need to give _process_received_pdu the actual state events
+ # rather than event ids, so generate that now.
+
+ # First though we need to fetch all the events that are in
+ # state_map, so we can build up the state below.
+ evs = await self.store.get_events(
+ list(state_map.values()),
+ get_prev_content=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ )
+ event_map.update(evs)
+
+ state = [event_map[e] for e in state_map.values()]
+ except Exception:
+ logger.warning(
+ "Error attempting to resolve state at missing prev_events",
+ exc_info=True,
+ )
+ raise FederationError(
+ "ERROR",
+ 403,
+ "We can't get valid state history.",
+ affected=event_id,
+ )
+ return state
+
+ async def _get_state_after_missing_prev_event(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ ) -> List[EventBase]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+
+ Returns:
+ A list of events in the state, including the event itself
+ """
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
+
+ logger.debug(
+ "state_ids returned %i state events, %i auth events",
+ len(state_event_ids),
+ len(auth_event_ids),
+ )
+
+ # start by just trying to fetch the events from the store
+ desired_events = set(state_event_ids)
+ desired_events.add(event_id)
+ logger.debug("Fetching %i events from cache/store", len(desired_events))
+ fetched_events = await self.store.get_events(
+ desired_events, allow_rejected=True
+ )
+
+ missing_desired_events = desired_events - fetched_events.keys()
+ logger.debug(
+ "We are missing %i events (got %i)",
+ len(missing_desired_events),
+ len(fetched_events),
+ )
+
+ # We probably won't need most of the auth events, so let's just check which
+ # we have for now, rather than thrashing the event cache with them all
+ # unnecessarily.
+
+ # TODO: we probably won't actually need all of the auth events, since we
+ # already have a bunch of the state events. It would be nice if the
+ # federation api gave us a way of finding out which we actually need.
+
+ missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+ missing_auth_events.difference_update(
+ await self.store.have_seen_events(room_id, missing_auth_events)
+ )
+ logger.debug("We are also missing %i auth events", len(missing_auth_events))
+
+ missing_events = missing_desired_events | missing_auth_events
+ logger.debug("Fetching %i events from remote", len(missing_events))
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
+ )
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ await self.store.get_events(missing_desired_events, allow_rejected=True)
+ )
+
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
+
+ bad_events = [
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ del fetched_events[bad_event_id]
+
+ # if we couldn't get the prev event in question, that's a problem.
+ remote_event = fetched_events.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+
+ # missing state at that event is a warning, not a blocker
+ # XXX: this doesn't sound right? it means that we'll end up with incomplete
+ # state.
+ failed_to_fetch = desired_events - fetched_events.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
+ ]
+
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
+
+ return remote_state
+
+ async def _process_received_pdu(
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ backfilled: bool = False,
+ ) -> None:
+ """Called when we have a new pdu. We need to do auth checks and put it
+ through the StateHandler.
+
+ Args:
+ origin: server sending the event
+
+ event: event to be persisted
+
+ state: Normally None, but if we are handling a gap in the graph
+ (ie, we are missing one or more prev_events), the resolved state at the
+ event
+
+ backfilled: True if this is part of a historical batch of events (inhibits
+ notification to clients, and validation of device keys.)
+ """
+ logger.debug("Processing event: %s", event)
+
+ try:
+ context = await self.state_handler.compute_event_context(
+ event, old_state=state
+ )
+ await self._auth_and_persist_event(
+ origin, event, context, state=state, backfilled=backfilled
+ )
+ except AuthError as e:
+ raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
+
+ if backfilled:
+ return
+
+ # For encrypted messages we check that we know about the sending device,
+ # if we don't then we mark the device cache for that user as stale.
+ if event.type == EventTypes.Encrypted:
+ device_id = event.content.get("device_id")
+ sender_key = event.content.get("sender_key")
+
+ cached_devices = await self.store.get_cached_devices_for_user(event.sender)
+
+ resync = False # Whether we should resync device lists.
+
+ device = None
+ if device_id is not None:
+ device = cached_devices.get(device_id)
+ if device is None:
+ logger.info(
+ "Received event from remote device not in our cache: %s %s",
+ event.sender,
+ device_id,
+ )
+ resync = True
+
+ # We also check if the `sender_key` matches what we expect.
+ if sender_key is not None:
+ # Figure out what sender key we're expecting. If we know the
+ # device and recognize the algorithm then we can work out the
+ # exact key to expect. Otherwise check it matches any key we
+ # have for that device.
+
+ current_keys: Container[str] = []
+
+ if device:
+ keys = device.get("keys", {}).get("keys", {})
+
+ if (
+ event.content.get("algorithm")
+ == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
+ ):
+ # For this algorithm we expect a curve25519 key.
+ key_name = "curve25519:%s" % (device_id,)
+ current_keys = [keys.get(key_name)]
+ else:
+ # We don't know understand the algorithm, so we just
+ # check it matches a key for the device.
+ current_keys = keys.values()
+ elif device_id:
+ # We don't have any keys for the device ID.
+ pass
+ else:
+ # The event didn't include a device ID, so we just look for
+ # keys across all devices.
+ current_keys = [
+ key
+ for device in cached_devices.values()
+ for key in device.get("keys", {}).get("keys", {}).values()
+ ]
+
+ # We now check that the sender key matches (one of) the expected
+ # keys.
+ if sender_key not in current_keys:
+ logger.info(
+ "Received event from remote device with unexpected sender key: %s %s: %s",
+ event.sender,
+ device_id or "",
+ sender_key,
+ )
+ resync = True
+
+ if resync:
+ run_as_background_process(
+ "resync_device_due_to_pdu",
+ self._resync_device,
+ event.sender,
+ )
+
+ await self._handle_marker_event(origin, event)
+
+ async def _resync_device(self, sender: str) -> None:
+ """We have detected that the device list for the given user may be out
+ of sync, so we try and resync them.
+ """
+
+ try:
+ await self.store.mark_remote_user_device_cache_as_stale(sender)
+
+ # Immediately attempt a resync in the background
+ if self.config.worker_app:
+ await self._user_device_resync(user_id=sender)
+ else:
+ await self._device_list_updater.user_device_resync(sender)
+ except Exception:
+ logger.exception("Failed to resync device for %s", sender)
+
+ async def _handle_marker_event(self, origin: str, marker_event: EventBase):
+ """Handles backfilling the insertion event when we receive a marker
+ event that points to one.
+
+ Args:
+ origin: Origin of the event. Will be called to get the insertion event
+ marker_event: The event to process
+ """
+
+ if marker_event.type != EventTypes.MSC2716_MARKER:
+ # Not a marker event
+ return
+
+ if marker_event.rejected_reason is not None:
+ # Rejected event
+ return
+
+ # Skip processing a marker event if the room version doesn't
+ # support it.
+ room_version = await self.store.get_room_version(marker_event.room_id)
+ if not room_version.msc2716_historical:
+ return
+
+ logger.debug("_handle_marker_event: received %s", marker_event)
+
+ insertion_event_id = marker_event.content.get(
+ EventContentFields.MSC2716_MARKER_INSERTION
+ )
+
+ if insertion_event_id is None:
+ # Nothing to retrieve then (invalid marker)
+ return
+
+ logger.debug(
+ "_handle_marker_event: backfilling insertion event %s", insertion_event_id
+ )
+
+ await self._get_events_and_persist(
+ origin,
+ marker_event.room_id,
+ [insertion_event_id],
+ )
+
+ insertion_event = await self.store.get_event(
+ insertion_event_id, allow_none=True
+ )
+ if insertion_event is None:
+ logger.warning(
+ "_handle_marker_event: server %s didn't return insertion event %s for marker %s",
+ origin,
+ insertion_event_id,
+ marker_event.event_id,
+ )
+ return
+
+ logger.debug(
+ "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s",
+ insertion_event,
+ marker_event,
+ )
+
+ await self.store.insert_insertion_extremity(
+ insertion_event_id, marker_event.room_id
+ )
+
+ logger.debug(
+ "_handle_marker_event: insertion extremity added for %s from marker event %s",
+ insertion_event,
+ marker_event,
+ )
+
+ async def _get_events_and_persist(
+ self, destination: str, room_id: str, events: Iterable[str]
+ ) -> None:
+ """Fetch the given events from a server, and persist them as outliers.
+
+ This function *does not* recursively get missing auth events of the
+ newly fetched events. Callers must include in the `events` argument
+ any missing events from the auth chain.
+
+ Logs a warning if we can't find the given event.
+ """
+
+ room_version = await self.store.get_room_version(room_id)
+
+ event_map: Dict[str, EventBase] = {}
+
+ async def get_event(event_id: str):
+ with nested_logging_context(event_id):
+ try:
+ event = await self.federation_client.get_pdu(
+ [destination],
+ event_id,
+ room_version,
+ outlier=True,
+ )
+ if event is None:
+ logger.warning(
+ "Server %s didn't return event %s",
+ destination,
+ event_id,
+ )
+ return
+
+ event_map[event.event_id] = event
+
+ except Exception as e:
+ logger.warning(
+ "Error fetching missing state/auth event %s: %s %s",
+ event_id,
+ type(e),
+ e,
+ )
+
+ await concurrently_execute(get_event, events, 5)
+
+ # Make a map of auth events for each event. We do this after fetching
+ # all the events as some of the events' auth events will be in the list
+ # of requested events.
+
+ auth_events = [
+ aid
+ for event in event_map.values()
+ for aid in event.auth_event_ids()
+ if aid not in event_map
+ ]
+ persisted_events = await self.store.get_events(
+ auth_events,
+ allow_rejected=True,
+ )
+
+ event_infos = []
+ for event in event_map.values():
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+ else:
+ logger.info("Missing auth event %s", auth_event_id)
+
+ event_infos.append(_NewEventInfo(event, auth))
+
+ if event_infos:
+ await self._auth_and_persist_events(
+ destination,
+ room_id,
+ event_infos,
+ )
+
+ async def _auth_and_persist_events(
+ self,
+ origin: str,
+ room_id: str,
+ event_infos: Collection[_NewEventInfo],
+ ) -> None:
+ """Creates the appropriate contexts and persists events. The events
+ should not depend on one another, e.g. this should be used to persist
+ a bunch of outliers, but not a chunk of individual events that depend
+ on each other for state calculations.
+
+ Notifies about the events where appropriate.
+ """
+
+ if not event_infos:
+ return
+
+ async def prep(ev_info: _NewEventInfo):
+ event = ev_info.event
+ with nested_logging_context(suffix=event.event_id):
+ res = await self.state_handler.compute_event_context(event)
+ res = await self._check_event_auth(
+ origin,
+ event,
+ res,
+ claimed_auth_event_map=ev_info.claimed_auth_event_map,
+ )
+ return res
+
+ contexts = await make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(prep, ev_info) for ev_info in event_infos],
+ consumeErrors=True,
+ )
+ )
+
+ await self.persist_events_and_notify(
+ room_id,
+ [
+ (ev_info.event, context)
+ for ev_info, context in zip(event_infos, contexts)
+ ],
+ )
+
+ async def _auth_and_persist_event(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ state: Optional[Iterable[EventBase]] = None,
+ claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> None:
+ """
+ Process an event by performing auth checks and then persisting to the database.
+
+ Args:
+ origin: The host the event originates from.
+ event: The event itself.
+ context:
+ The event context.
+
+ state:
+ The state events used to check the event for soft-fail. If this is
+ not provided the current state events will be used.
+
+ claimed_auth_event_map:
+ A map of (type, state_key) => event for the event's claimed auth_events.
+ Possibly incomplete, and possibly including events that are not yet
+ persisted, or authed, or in the right room.
+
+ Only populated where we may not already have persisted these events -
+ for example, when populating outliers.
+
+ backfilled: True if the event was backfilled.
+ """
+ context = await self._check_event_auth(
+ origin,
+ event,
+ context,
+ state=state,
+ claimed_auth_event_map=claimed_auth_event_map,
+ backfilled=backfilled,
+ )
+
+ await self._run_push_actions_and_persist_event(event, context, backfilled)
+
+ async def _check_event_auth(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ state: Optional[Iterable[EventBase]] = None,
+ claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> EventContext:
+ """
+ Checks whether an event should be rejected (for failing auth checks).
+
+ Args:
+ origin: The host the event originates from.
+ event: The event itself.
+ context:
+ The event context.
+
+ state:
+ The state events used to check the event for soft-fail. If this is
+ not provided the current state events will be used.
+
+ claimed_auth_event_map:
+ A map of (type, state_key) => event for the event's claimed auth_events.
+ Possibly incomplete, and possibly including events that are not yet
+ persisted, or authed, or in the right room.
+
+ Only populated where we may not already have persisted these events -
+ for example, when populating outliers, or the state for a backwards
+ extremity.
+
+ backfilled: True if the event was backfilled.
+
+ Returns:
+ The updated context object.
+ """
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ if claimed_auth_event_map:
+ # if we have a copy of the auth events from the event, use that as the
+ # basis for auth.
+ auth_events = claimed_auth_event_map
+ else:
+ # otherwise, we calculate what the auth events *should* be, and use that
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = self._event_auth_handler.compute_auth_events(
+ event, prev_state_ids, for_verification=True
+ )
+ auth_events_x = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+
+ try:
+ (
+ context,
+ auth_events_for_auth,
+ ) = await self._update_auth_events_and_context_for_auth(
+ origin, event, context, auth_events
+ )
+ except Exception:
+ # We don't really mind if the above fails, so lets not fail
+ # processing if it does. However, it really shouldn't fail so
+ # let's still log as an exception since we'll still want to fix
+ # any bugs.
+ logger.exception(
+ "Failed to double check auth events for %s with remote. "
+ "Ignoring failure and continuing processing of event.",
+ event.event_id,
+ )
+ auth_events_for_auth = auth_events
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
+ except AuthError as e:
+ logger.warning("Failed auth resolution for %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
+
+ if not context.rejected:
+ await self._check_for_soft_fail(event, state, backfilled, origin=origin)
+
+ if event.type == EventTypes.GuestAccess and not context.rejected:
+ await self.maybe_kick_guest_users(event)
+
+ # If we are going to send this event over federation we precaclculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(
+ event, context
+ )
+
+ return context
+
+ async def _check_for_soft_fail(
+ self,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ backfilled: bool,
+ origin: str,
+ ) -> None:
+ """Checks if we should soft fail the event; if so, marks the event as
+ such.
+
+ Args:
+ event
+ state: The state at the event if we don't have all the event's prev events
+ backfilled: Whether the event is from backfill
+ origin: The host the event originates from.
+ """
+ # For new (non-backfilled and non-outlier) events we check if the event
+ # passes auth based on the current state. If it doesn't then we
+ # "soft-fail" the event.
+ if backfilled or event.internal_metadata.is_outlier():
+ return
+
+ extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids_list)
+ prev_event_ids = set(event.prev_event_ids())
+
+ if extrem_ids == prev_event_ids:
+ # If they're the same then the current state is the same as the
+ # state at the event, so no point rechecking auth for soft fail.
+ return
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ # Calculate the "current state".
+ if state is not None:
+ # If we're explicitly given the state then we won't have all the
+ # prev events, and so we have a gap in the graph. In this case
+ # we want to be a little careful as we might have been down for
+ # a while and have an incorrect view of the current state,
+ # however we still want to do checks as gaps are easy to
+ # maliciously manufacture.
+ #
+ # So we use a "current state" that is actually a state
+ # resolution across the current forward extremities and the
+ # given state at the event. This should correctly handle cases
+ # like bans, especially with state res v2.
+
+ state_sets_d = await self.state_store.get_state_groups(
+ event.room_id, extrem_ids
+ )
+ state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
+ state_sets.append(state)
+ current_states = await self.state_handler.resolve_events(
+ room_version, state_sets, event
+ )
+ current_state_ids: StateMap[str] = {
+ k: e.event_id for k, e in current_states.items()
+ }
+ else:
+ current_state_ids = await self.state_handler.get_current_state_ids(
+ event.room_id, latest_event_ids=extrem_ids
+ )
+
+ logger.debug(
+ "Doing soft-fail check for %s: state %s",
+ event.event_id,
+ current_state_ids,
+ )
+
+ # Now check if event pass auth against said current state
+ auth_types = auth_types_for_event(room_version_obj, event)
+ current_state_ids_list = [
+ e for k, e in current_state_ids.items() if k in auth_types
+ ]
+
+ auth_events_map = await self.store.get_events(current_state_ids_list)
+ current_auth_events = {
+ (e.type, e.state_key): e for e in auth_events_map.values()
+ }
+
+ try:
+ event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+ except AuthError as e:
+ logger.warning(
+ "Soft-failing %r (from %s) because %s",
+ event,
+ e,
+ origin,
+ extra={
+ "room_id": event.room_id,
+ "mxid": event.sender,
+ "hs": origin,
+ },
+ )
+ soft_failed_event_counter.inc()
+ event.internal_metadata.soft_failed = True
+
+ async def _update_auth_events_and_context_for_auth(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ input_auth_events: StateMap[EventBase],
+ ) -> Tuple[EventContext, StateMap[EventBase]]:
+ """Helper for _check_event_auth. See there for docs.
+
+ Checks whether a given event has the expected auth events. If it
+ doesn't then we talk to the remote server to compare state to see if
+ we can come to a consensus (e.g. if one server missed some valid
+ state).
+
+ This attempts to resolve any potential divergence of state between
+ servers, but is not essential and so failures should not block further
+ processing of the event.
+
+ Args:
+ origin:
+ event:
+ context:
+
+ input_auth_events:
+ Map from (event_type, state_key) to event
+
+ Normally, our calculated auth_events based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier), may be the auth events claimed by the remote
+ server.
+
+ Returns:
+ updated context, updated auth event map
+ """
+ # take a copy of input_auth_events before we modify it.
+ auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
+
+ event_auth_events = set(event.auth_event_ids())
+
+ # missing_auth is the set of the event's auth_events which we don't yet have
+ # in auth_events.
+ missing_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ # if we have missing events, we need to fetch those events from somewhere.
+ #
+ # we start by checking if they are in the store, and then try calling /event_auth/.
+ if missing_auth:
+ have_events = await self.store.have_seen_events(event.room_id, missing_auth)
+ logger.debug("Events %s are in the store", have_events)
+ missing_auth.difference_update(have_events)
+
+ if missing_auth:
+ # If we don't have all the auth events, we need to get them.
+ logger.info("auth_events contains unknown events: %s", missing_auth)
+ try:
+ try:
+ remote_auth_chain = await self.federation_client.get_event_auth(
+ origin, event.room_id, event.event_id
+ )
+ except RequestSendFailed as e1:
+ # The other side isn't around or doesn't implement the
+ # endpoint, so lets just bail out.
+ logger.info("Failed to get event auth from remote: %s", e1)
+ return context, auth_events
+
+ seen_remotes = await self.store.have_seen_events(
+ event.room_id, [e.event_id for e in remote_auth_chain]
+ )
+
+ for e in remote_auth_chain:
+ if e.event_id in seen_remotes:
+ continue
+
+ if e.event_id == event.event_id:
+ continue
+
+ try:
+ auth_ids = e.auth_event_ids()
+ auth = {
+ (e.type, e.state_key): e
+ for e in remote_auth_chain
+ if e.event_id in auth_ids or e.type == EventTypes.Create
+ }
+ e.internal_metadata.outlier = True
+
+ logger.debug(
+ "_check_event_auth %s missing_auth: %s",
+ event.event_id,
+ e.event_id,
+ )
+ missing_auth_event_context = (
+ await self.state_handler.compute_event_context(e)
+ )
+ await self._auth_and_persist_event(
+ origin,
+ e,
+ missing_auth_event_context,
+ claimed_auth_event_map=auth,
+ )
+
+ if e.event_id in event_auth_events:
+ auth_events[(e.type, e.state_key)] = e
+ except AuthError:
+ pass
+
+ except Exception:
+ logger.exception("Failed to get auth chain")
+
+ if event.internal_metadata.is_outlier():
+ # XXX: given that, for an outlier, we'll be working with the
+ # event's *claimed* auth events rather than those we calculated:
+ # (a) is there any point in this test, since different_auth below will
+ # obviously be empty
+ # (b) alternatively, why don't we do it earlier?
+ logger.info("Skipping auth_event fetch for outlier")
+ return context, auth_events
+
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ if not different_auth:
+ return context, auth_events
+
+ logger.info(
+ "auth_events refers to events which are not in our calculated auth "
+ "chain: %s",
+ different_auth,
+ )
+
+ # XXX: currently this checks for redactions but I'm not convinced that is
+ # necessary?
+ different_events = await self.store.get_events_as_list(different_auth)
+
+ for d in different_events:
+ if d.room_id != event.room_id:
+ logger.warning(
+ "Event %s refers to auth_event %s which is in a different room",
+ event.event_id,
+ d.event_id,
+ )
+
+ # don't attempt to resolve the claimed auth events against our own
+ # in this case: just use our own auth events.
+ #
+ # XXX: should we reject the event in this case? It feels like we should,
+ # but then shouldn't we also do so if we've failed to fetch any of the
+ # auth events?
+ return context, auth_events
+
+ # now we state-resolve between our own idea of the auth events, and the remote's
+ # idea of them.
+
+ local_state = auth_events.values()
+ remote_auth_events = dict(auth_events)
+ remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
+ remote_state = remote_auth_events.values()
+
+ room_version = await self.store.get_room_version_id(event.room_id)
+ new_state = await self.state_handler.resolve_events(
+ room_version, (local_state, remote_state), event
+ )
+
+ logger.info(
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id
+ for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
+ )
+
+ auth_events.update(new_state)
+
+ context = await self._update_context_for_auth_events(
+ event, context, auth_events
+ )
+
+ return context, auth_events
+
+ async def _update_context_for_auth_events(
+ self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
+ ) -> EventContext:
+ """Update the state_ids in an event context after auth event resolution,
+ storing the changes as a new state group.
+
+ Args:
+ event: The event we're handling the context for
+
+ context: initial event context
+
+ auth_events: Events to update in the event context.
+
+ Returns:
+ new event context
+ """
+ # exclude the state key of the new event from the current_state in the context.
+ if event.is_state():
+ event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
+ else:
+ event_key = None
+ state_updates = {
+ k: a.event_id for k, a in auth_events.items() if k != event_key
+ }
+
+ current_state_ids = await context.get_current_state_ids()
+ current_state_ids = dict(current_state_ids) # type: ignore
+
+ current_state_ids.update(state_updates)
+
+ prev_state_ids = await context.get_prev_state_ids()
+ prev_state_ids = dict(prev_state_ids)
+
+ prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
+
+ # create a new state group as a delta from the existing one.
+ prev_group = context.state_group
+ state_group = await self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ current_state_ids=current_state_ids,
+ )
+
+ return EventContext.with_state(
+ state_group=state_group,
+ state_group_before_event=context.state_group_before_event,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ prev_group=prev_group,
+ delta_ids=state_updates,
+ )
+
+ async def _run_push_actions_and_persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ):
+ """Run the push actions for a received event, and persist it.
+
+ Args:
+ event: The event itself.
+ context: The event context.
+ backfilled: True if the event was backfilled.
+ """
+ try:
+ if (
+ not event.internal_metadata.is_outlier()
+ and not backfilled
+ and not context.rejected
+ and (await self.store.get_min_depth(event.room_id)) <= event.depth
+ ):
+ await self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
+
+ await self.persist_events_and_notify(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
+ except Exception:
+ run_in_background(
+ self.store.remove_push_actions_from_staging, event.event_id
+ )
+ raise
+
+ async def persist_events_and_notify(
+ self,
+ room_id: str,
+ event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
+ backfilled: bool = False,
+ ) -> int:
+ """Persists events and tells the notifier/pushers about them, if
+ necessary.
+
+ Args:
+ room_id: The room ID of events being persisted.
+ event_and_contexts: Sequence of events with their associated
+ context that should be persisted. All events must belong to
+ the same room.
+ backfilled: Whether these events are a result of
+ backfilling or not
+
+ Returns:
+ The stream ID after which all events have been persisted.
+ """
+ if not event_and_contexts:
+ return self.store.get_current_events_token()
+
+ instance = self.config.worker.events_shard_config.get_instance(room_id)
+ if instance != self._instance_name:
+ # Limit the number of events sent over replication. We choose 200
+ # here as that is what we default to in `max_request_body_size(..)`
+ for batch in batch_iter(event_and_contexts, 200):
+ result = await self._send_events(
+ instance_name=instance,
+ store=self.store,
+ room_id=room_id,
+ event_and_contexts=batch,
+ backfilled=backfilled,
+ )
+ return result["max_stream_id"]
+ else:
+ assert self.storage.persistence
+
+ # Note that this returns the events that were persisted, which may not be
+ # the same as were passed in if some were deduplicated due to transaction IDs.
+ events, max_stream_token = await self.storage.persistence.persist_events(
+ event_and_contexts, backfilled=backfilled
+ )
+
+ if self._ephemeral_messages_enabled:
+ for event in events:
+ # If there's an expiry timestamp on the event, schedule its expiry.
+ self._message_handler.maybe_schedule_expiry(event)
+
+ if not backfilled: # Never notify for backfilled events
+ for event in events:
+ await self._notify_persisted_event(event, max_stream_token)
+
+ return max_stream_token.stream
+
+ async def _notify_persisted_event(
+ self, event: EventBase, max_stream_token: RoomStreamToken
+ ) -> None:
+ """Checks to see if notifier/pushers should be notified about the
+ event or not.
+
+ Args:
+ event:
+ max_stream_token: The max_stream_id returned by persist_events
+ """
+
+ extra_users = []
+ if event.type == EventTypes.Member:
+ target_user_id = event.state_key
+
+ # We notify for memberships if its an invite for one of our
+ # users
+ if event.internal_metadata.is_outlier():
+ if event.membership != Membership.INVITE:
+ if not self.is_mine_id(target_user_id):
+ return
+
+ target_user = UserID.from_string(target_user_id)
+ extra_users.append(target_user)
+ elif event.internal_metadata.is_outlier():
+ return
+
+ # the event has been persisted so it should have a stream ordering.
+ assert event.internal_metadata.stream_ordering
+
+ event_pos = PersistedEventPosition(
+ self._instance_name, event.internal_metadata.stream_ordering
+ )
+ self.notifier.on_new_room_event(
+ event, event_pos, max_stream_token, extra_users=extra_users
+ )
+
+ def _sanity_check_event(self, ev: EventBase) -> None:
+ """
+ Do some early sanity checks of a received event
+
+ In particular, checks it doesn't have an excessive number of
+ prev_events or auth_events, which could cause a huge state resolution
+ or cascade of event fetches.
+
+ Args:
+ ev: event to be checked
+
+ Raises:
+ SynapseError if the event does not pass muster
+ """
+ if len(ev.prev_event_ids()) > 20:
+ logger.warning(
+ "Rejecting event %s which has %i prev_events",
+ ev.event_id,
+ len(ev.prev_event_ids()),
+ )
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events")
+
+ if len(ev.auth_event_ids()) > 10:
+ logger.warning(
+ "Rejecting event %s which has %i auth_events",
+ ev.event_id,
+ len(ev.auth_event_ids()),
+ )
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
+
+ async def get_min_depth_for_context(self, context: str) -> int:
+ return await self.store.get_min_depth(context)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 79cadb7b57..a0b3145f4e 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.clock = hs.get_clock()
- self.federation_handler = hs.get_federation_handler()
+ self.federation_event_handler = hs.get_federation_event_handler()
@staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
@@ -127,7 +127,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
- max_stream_id = await self.federation_handler.persist_events_and_notify(
+ max_stream_id = await self.federation_event_handler.persist_events_and_notify(
room_id, event_and_contexts, backfilled
)
diff --git a/synapse/server.py b/synapse/server.py
index de6517663e..5adeeff61a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -76,6 +76,7 @@ from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
from synapse.handlers.event_auth import EventAuthHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.federation import FederationHandler
+from synapse.handlers.federation_event import FederationEventHandler
from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
from synapse.handlers.identity import IdentityHandler
from synapse.handlers.initial_sync import InitialSyncHandler
@@ -546,6 +547,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_federation_handler(self) -> FederationHandler:
return FederationHandler(self)
+ @cache_in_self
+ def get_federation_event_handler(self) -> FederationEventHandler:
+ return FederationEventHandler(self)
+
@cache_in_self
def get_identity_handler(self) -> IdentityHandler:
return IdentityHandler(self)
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 383214ab50..663960ff53 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -208,7 +208,7 @@ class FederationKnockingTestCase(
async def _check_event_auth(origin, event, context, *args, **kwargs):
return context
- homeserver.get_federation_handler()._check_event_auth = _check_event_auth
+ homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth
return super().prepare(reactor, clock, homeserver)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index c72a8972a3..6c67a16de9 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -130,7 +130,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
with LoggingContext("send_rejected"):
- d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ d = run_in_background(
+ self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
+ )
self.get_success(d)
# that should have been rejected
@@ -182,7 +184,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
with LoggingContext("send_rejected"):
- d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ d = run_in_background(
+ self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
+ )
self.get_success(d)
# that should have been rejected
@@ -311,7 +315,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
d = run_in_background(
- self.handler.on_receive_pdu, OTHER_SERVER, message_event
+ self.hs.get_federation_event_handler().on_receive_pdu,
+ OTHER_SERVER,
+ message_event,
)
self.get_success(d)
@@ -382,7 +388,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
join_event.signatures[other_server] = {"x": "y"}
with LoggingContext("send_join"):
d = run_in_background(
- self.handler.on_send_membership_event, other_server, join_event
+ self.hs.get_federation_event_handler().on_send_membership_event,
+ other_server,
+ join_event,
)
self.get_success(d)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 0a52bc8b72..671dc7d083 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -885,7 +885,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
- self.federation_handler = hs.get_federation_handler()
+ self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler()
# self.event_builder_for_2 = EventBuilderFactory(hs)
@@ -1026,7 +1026,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
)
- self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
+ self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
# Check that it was successfully persisted.
self.get_success(self.store.get_event(event.event_id))
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index af5dfca752..92a5b53e11 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -205,7 +205,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
- federation = self.hs.get_federation_handler()
+ federation = self.hs.get_federation_event_handler()
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 348fcb72a7..61c9d7c2ef 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,7 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
self.handler = self.homeserver.get_federation_handler()
- self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
+ federation_event_handler = self.homeserver.get_federation_event_handler()
+ federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
context
)
self.client = self.homeserver.get_federation_client()
@@ -85,7 +86,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Send the join, it should return None (which is not an error)
self.assertEqual(
- self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
+ self.get_success(
+ federation_event_handler.on_receive_pdu("test.serv", join_event)
+ ),
None,
)
@@ -129,9 +132,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
+ federation_event_handler = self.homeserver.get_federation_event_handler()
with LoggingContext("test-context"):
failure = self.get_failure(
- self.handler.on_receive_pdu("test.serv", lying_event),
+ federation_event_handler.on_receive_pdu("test.serv", lying_event),
FederationError,
)
--
cgit 1.5.1
From c4fa4f37cbc734f9cd6354a5f2661efc30d73cac Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 27 Aug 2021 10:15:50 +0100
Subject: Fix perf of fetching the same events many times. (#10703)
The code to deduplicate repeated fetches of the same set of events was
N^2 (over the number of events requested), which could lead to a process
being completely wedged.
The main fix is to deduplicate the returned deferreds so we only await
on a deferred once rather than many times. Seperately, when handling the
returned events from the defrered we only add the events we care about
to the event map to be returned (so that we don't pay the price of
inserting extraneous events into the dict).
---
changelog.d/10703.bugfix | 1 +
synapse/storage/databases/main/events_worker.py | 29 ++++++++++++++++++++-----
2 files changed, 24 insertions(+), 6 deletions(-)
create mode 100644 changelog.d/10703.bugfix
(limited to 'synapse')
diff --git a/changelog.d/10703.bugfix b/changelog.d/10703.bugfix
new file mode 100644
index 0000000000..a5a4ecf8ee
--- /dev/null
+++ b/changelog.d/10703.bugfix
@@ -0,0 +1 @@
+Fix a regression introduced in v1.41.0 which affected the performance of concurrent fetches of large sets of events, in extreme cases causing the process to hang.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 375463e4e9..9501f00f3b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -520,16 +520,26 @@ class EventsWorkerStore(SQLBaseStore):
# We now look up if we're already fetching some of the events in the DB,
# if so we wait for those lookups to finish instead of pulling the same
# events out of the DB multiple times.
- already_fetching: Dict[str, defer.Deferred] = {}
+ #
+ # Note: we might get the same `ObservableDeferred` back for multiple
+ # events we're already fetching, so we deduplicate the deferreds to
+ # avoid extraneous work (if we don't do this we can end up in a n^2 mode
+ # when we wait on the same Deferred N times, then try and merge the
+ # same dict into itself N times).
+ already_fetching_ids: Set[str] = set()
+ already_fetching_deferreds: Set[
+ ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ] = set()
for event_id in missing_events_ids:
deferred = self._current_event_fetches.get(event_id)
if deferred is not None:
# We're already pulling the event out of the DB. Add the deferred
# to the collection of deferreds to wait on.
- already_fetching[event_id] = deferred.observe()
+ already_fetching_ids.add(event_id)
+ already_fetching_deferreds.add(deferred)
- missing_events_ids.difference_update(already_fetching)
+ missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids:
log_ctx = current_context()
@@ -569,18 +579,25 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
- if already_fetching:
+ if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results
# to ours.
results = await make_deferred_yieldable(
defer.gatherResults(
- already_fetching.values(),
+ (d.observe() for d in already_fetching_deferreds),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
for result in results:
- event_entry_map.update(result)
+ # We filter out events that we haven't asked for as we might get
+ # a *lot* of superfluous events back, and there is no point
+ # going through and inserting them all (which can take time).
+ event_entry_map.update(
+ (event_id, entry)
+ for event_id, entry in result.items()
+ if event_id in already_fetching_ids
+ )
if not allow_rejected:
event_entry_map = {
--
cgit 1.5.1
From e62cdbef1a499f428e48f98167b2b709d16c671d Mon Sep 17 00:00:00 2001
From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com>
Date: Fri, 27 Aug 2021 11:16:40 +0200
Subject: Improve ServerNoticeServlet to avoid duplicate requests (#10679)
Fixes: #9544
---
changelog.d/10679.bugfix | 1 +
synapse/rest/admin/__init__.py | 5 +-
synapse/rest/admin/server_notice_servlet.py | 19 +-
synapse/server_notices/server_notices_manager.py | 17 +-
tests/rest/admin/test_server_notice.py | 450 +++++++++++++++++++++++
5 files changed, 475 insertions(+), 17 deletions(-)
create mode 100644 changelog.d/10679.bugfix
create mode 100644 tests/rest/admin/test_server_notice.py
(limited to 'synapse')
diff --git a/changelog.d/10679.bugfix b/changelog.d/10679.bugfix
new file mode 100644
index 0000000000..5c4061f6d5
--- /dev/null
+++ b/changelog.d/10679.bugfix
@@ -0,0 +1 @@
+Improve ServerNoticeServlet to avoid duplicate requests and add unit tests.
\ No newline at end of file
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6e1c8736e1..b2514d9d0d 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -223,7 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
- SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
@@ -247,6 +246,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
NewRegistrationTokenRestServlet(hs).register(http_server)
RegistrationTokenRestServlet(hs).register(http_server)
+ # Some servlets only get registered for the main process.
+ if hs.config.worker_app is None:
+ SendServerNoticeServlet(hs).register(http_server)
+
def register_servlets_for_client_rest_resource(
hs: "HomeServer", http_server: HttpServer
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index b5e4c474ef..42201afc86 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
+ self.server_notices_manager = hs.get_server_notices_manager()
+ self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer):
@@ -79,19 +81,22 @@ class SendServerNoticeServlet(RestServlet):
# We grab the server notices manager here as its initialisation has a check for worker processes,
# but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
# admin api).
- if not self.hs.get_server_notices_manager().is_enabled():
+ if not self.server_notices_manager.is_enabled():
raise SynapseError(400, "Server notices are not enabled on this server")
- user_id = body["user_id"]
- UserID.from_string(user_id)
- if not self.hs.is_mine_id(user_id):
+ target_user = UserID.from_string(body["user_id"])
+ if not self.hs.is_mine(target_user):
raise SynapseError(400, "Server notices can only be sent to local users")
- event = await self.hs.get_server_notices_manager().send_notice(
- user_id=body["user_id"],
+ if not await self.admin_handler.get_user(target_user):
+ raise NotFoundError("User not found")
+
+ event = await self.server_notices_manager.send_notice(
+ user_id=target_user.to_string(),
type=event_type,
state_key=state_key,
event_content=body["content"],
+ txn_id=txn_id,
)
return 200, {"event_id": event.event_id}
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index f19075b760..d87a538917 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -12,26 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.events import EventBase
from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
SERVER_NOTICE_ROOM_TAG = "m.server_notice"
class ServerNoticesManager:
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
-
+ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._config = hs.config
self._account_data_handler = hs.get_account_data_handler()
@@ -58,6 +55,7 @@ class ServerNoticesManager:
event_content: dict,
type: str = EventTypes.Message,
state_key: Optional[str] = None,
+ txn_id: Optional[str] = None,
) -> EventBase:
"""Send a notice to the given user
@@ -68,6 +66,7 @@ class ServerNoticesManager:
event_content: content of event to send
type: type of event
is_state_event: Is the event a state event
+ txn_id: The transaction ID.
"""
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
@@ -90,7 +89,7 @@ class ServerNoticesManager:
event_dict["state_key"] = state_key
event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, ratelimit=False
+ requester, event_dict, ratelimit=False, txn_id=txn_id
)
return event
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
new file mode 100644
index 0000000000..fbceba3254
--- /dev/null
+++ b/tests/rest/admin/test_server_notice.py
@@ -0,0 +1,450 @@
+# Copyright 2021 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login, room, sync
+from synapse.storage.roommember import RoomsForUser
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class ServerNoticeTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.room_shutdown_handler = hs.get_room_shutdown_handler()
+ self.pagination_handler = hs.get_pagination_handler()
+ self.server_notices_manager = self.hs.get_server_notices_manager()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/send_server_notice"
+
+ def test_no_auth(self):
+ """Try to send a server notice without authentication."""
+ channel = self.make_request("POST", self.url)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """If the user is not a server admin, an error is returned."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_user_does_not_exist(self):
+ """Tests that a lookup for a user that does not exist returns a 404"""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": "@unknown_person:test", "content": ""},
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": "@unknown_person:unknown_domain",
+ "content": "",
+ },
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Server notices can only be sent to local users", channel.json_body["error"]
+ )
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_invalid_parameter(self):
+ """If parameters are invalid, an error is returned."""
+
+ # no content, no user
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+ # no content
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # no body
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user, "content": ""},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual("'body' not in content", channel.json_body["error"])
+
+ # no msgtype
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user, "content": {"body": ""}},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+
+ def test_server_notice_disabled(self):
+ """Tests that server returns error if server notice is disabled"""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": "",
+ },
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(
+ "Server notices are not enabled on this server", channel.json_body["error"]
+ )
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice(self):
+ """
+ Tests that sending two server notices is successfully,
+ the server uses the same room and do not send messages twice.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token)
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # invalidate cache of server notices room_ids
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has no new invites or memberships
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 2)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ self.assertEqual(messages[1]["content"]["body"], "test msg two")
+ self.assertEqual(messages[1]["sender"], "@notices:test")
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice_leave_room(self):
+ """
+ Tests that sending a server notices is successfully.
+ The user leaves the room and the second message appears
+ in a new room.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ first_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # user leaves the romm
+ self.helper.leave(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+
+ # user is not member anymore
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # invalidate cache of server notices room_ids
+ # if server tries to send to a cached room_id the user gets the message
+ # in old room
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ second_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=second_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg two")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ # room has the same id
+ self.assertNotEqual(first_room_id, second_room_id)
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice_delete_room(self):
+ """
+ Tests that the user get server notice in a new room
+ after the first server notice room was deleted.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ first_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # shut down and purge room
+ self.get_success(
+ self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user)
+ )
+ self.get_success(self.pagination_handler.purge_room(first_room_id))
+
+ # user is not member anymore
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(first_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
+
+ # invalidate cache of server notices room_ids
+ # if server tries to send to a cached room_id it gives an error
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ second_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=second_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get message
+ messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg two")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ # second room has new ID
+ self.assertNotEqual(first_room_id, second_room_id)
+
+ def _check_invite_and_join_status(
+ self, user_id: str, expected_invites: int, expected_memberships: int
+ ) -> RoomsForUser:
+ """Check invite and room membership status of a user.
+
+ Args
+ user_id: user to check
+ expected_invites: number of expected invites of this user
+ expected_memberships: number of expected room memberships of this user
+ Returns
+ room_ids from the rooms that the user is invited
+ """
+
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(user_id)
+ )
+ self.assertEqual(expected_invites, len(invited_rooms))
+
+ room_ids = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(expected_memberships, len(room_ids))
+
+ return invited_rooms
+
+ def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]:
+ """
+ Do a sync and get messages of a room.
+
+ Args
+ room_id: room that contains the messages
+ token: access token of user
+
+ Returns
+ list of messages contained in the room
+ """
+ channel = self.make_request(
+ "GET", "/_matrix/client/r0/sync", access_token=token
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Get the messages
+ room = channel.json_body["rooms"]["join"][room_id]
+ messages = [
+ x for x in room["timeline"]["events"] if x["type"] == "m.room.message"
+ ]
+ return messages
--
cgit 1.5.1
From 029b7ad7b94d167b19d63a5dc777a806b0e073f3 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 27 Aug 2021 07:08:02 -0400
Subject: Remove unused `compare_digest` function. (#10706)
---
changelog.d/10706.misc | 1 +
synapse/rest/client/register.py | 13 -------------
2 files changed, 1 insertion(+), 13 deletions(-)
create mode 100644 changelog.d/10706.misc
(limited to 'synapse')
diff --git a/changelog.d/10706.misc b/changelog.d/10706.misc
new file mode 100644
index 0000000000..eed4aa58d6
--- /dev/null
+++ b/changelog.d/10706.misc
@@ -0,0 +1 @@
+Remove unused `compare_digest` function.
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 2781a0ea96..7b5f49d635 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import hmac
import logging
import random
from typing import List, Union
@@ -60,18 +59,6 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler
-# We ought to be using hmac.compare_digest() but on older pythons it doesn't
-# exist. It's a _really minor_ security flaw to use plain string comparison
-# because the timing attack is so obscured by all the other code here it's
-# unlikely to make much difference
-if hasattr(hmac, "compare_digest"):
- compare_digest = hmac.compare_digest
-else:
-
- def compare_digest(a, b):
- return a == b
-
-
logger = logging.getLogger(__name__)
--
cgit 1.5.1
From a4c8a2f08b735266fbbe2f259e640f00dc5e3a00 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff
Date: Tue, 31 Aug 2021 13:42:51 +0100
Subject: 1.41.1
---
CHANGES.md | 32 ++++++++++++++++++++++++++++++++
debian/changelog | 6 ++++++
synapse/__init__.py | 2 +-
3 files changed, 39 insertions(+), 1 deletion(-)
(limited to 'synapse')
diff --git a/CHANGES.md b/CHANGES.md
index f8da8771aa..fab27b874e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,35 @@
+Synapse 1.41.1 (2021-08-31)
+===========================
+
+Due to the two security issues highlighted below, server administrators are encouraged to update Synapse. We are not aware of these vulnerabilities being exploited in the wild.
+
+Security advisory
+-----------------
+
+The following issues are fixed in v1.41.1.
+
+- **[GHSA-3x4c-pq33-4w3q](https://github.com/matrix-org/synapse/security/advisories/GHSA-3x4c-pq33-4w3q) / [CVE-2021-39164](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-39164): Enumerating a private room's list of members and their display names.**
+
+ If an unauthorized user both knows the Room ID of a private room *and* that room's history visibility is set to `shared`, then they may be able to enumerate the room's members, including their display names.
+
+ The unauthorized user must be on the same homeserver as a user who is a member of the target room.
+
+ Fixed by [52c7a51cf](https://github.com/matrix-org/synapse/commit/52c7a51cf).
+
+- **[GHSA-jj53-8fmw-f2w2](https://github.com/matrix-org/synapse/security/advisories/GHSA-jj53-8fmw-f2w2) / [CVE-2021-39163](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-39163): Disclosing a private room's name, avatar, topic, and number of members.**
+
+ If an unauthorized user knows the Room ID of a private room, then its name, avatar, topic, and number of members may be disclosed through Group / Community features.
+
+ The unauthorized user must be on the same homeserver as a user who is a member of the target room, and their homeserver must allow non-administrators to create groups (`enable_group_creation` in the Synapse configuration; off by default).
+
+ Fixed by [cb35df940a](https://github.com/matrix-org/synapse/commit/cb35df940a), [\#10723](https://github.com/matrix-org/synapse/issues/10723).
+
+Bugfixes
+--------
+
+- Fix a regression introduced in Synapse 1.41 which broke email transmission on systems using older versions of the Twisted library. ([\#10713](https://github.com/matrix-org/synapse/issues/10713))
+
+
Synapse 1.41.0 (2021-08-24)
===========================
diff --git a/debian/changelog b/debian/changelog
index 4da4bc018c..5f7a795b6e 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.41.1) stable; urgency=high
+
+ * New synapse release 1.41.1.
+
+ -- Synapse Packaging team Tue, 31 Aug 2021 12:59:10 +0100
+
matrix-synapse-py3 (1.41.0) stable; urgency=medium
* New synapse release 1.41.0.
diff --git a/synapse/__init__.py b/synapse/__init__.py
index ef3770262e..06d80f79b3 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.41.0"
+__version__ = "1.41.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
--
cgit 1.5.1