summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-04-23 15:46:29 +0100
committerGitHub <noreply@github.com>2021-04-23 15:46:29 +0100
commita15c003e5b0bff8bf78a675f3b719d3f25fe8bde (patch)
tree70b5d0d2e6745d84ad2b2d62963eaf49e65b1404 /synapse
parentRemove room and user invite ratelimits in default unit test config (#9871) (diff)
downloadsynapse-a15c003e5b0bff8bf78a675f3b719d3f25fe8bde.tar.xz
Make DomainSpecificString an attrs class (#9875)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/oidc.py5
-rw-r--r--synapse/rest/synapse/client/new_user_consent.py9
-rw-r--r--synapse/types.py17
3 files changed, 23 insertions, 8 deletions
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 45514be50f..1c4a43be0a 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -957,6 +957,11 @@ class OidcProvider:
                 # and attempt to match it.
                 attributes = await oidc_response_to_user_attributes(failures=0)
 
+                if attributes.localpart is None:
+                    # If no localpart is returned then we will generate one, so
+                    # there is no need to search for existing users.
+                    return None
+
                 user_id = UserID(attributes.localpart, self._server_name).to_string()
                 users = await self._store.get_users_by_id_case_insensitive(user_id)
                 if users:
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index e5634f9679..488b97b32e 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -61,6 +61,15 @@ class NewUserConsentResource(DirectServeHtmlResource):
             self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
             return
 
+        # It should be impossible to get here without having first been through
+        # the pick-a-username step, which ensures chosen_localpart gets set.
+        if not session.chosen_localpart:
+            logger.warning("Session has no user name selected")
+            self._sso_handler.render_error(
+                request, "no_user", "No user name has been selected.", code=400
+            )
+            return
+
         user_id = UserID(session.chosen_localpart, self._server_name)
         user_profile = {
             "display_name": session.display_name,
diff --git a/synapse/types.py b/synapse/types.py
index e19f28d543..e52cd7ffd4 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -199,9 +199,8 @@ def get_localpart_from_id(string):
 DS = TypeVar("DS", bound="DomainSpecificString")
 
 
-class DomainSpecificString(
-    namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta
-):
+@attr.s(slots=True, frozen=True, repr=False)
+class DomainSpecificString(metaclass=abc.ABCMeta):
     """Common base class among ID/name strings that have a local part and a
     domain name, prefixed with a sigil.
 
@@ -213,11 +212,8 @@ class DomainSpecificString(
 
     SIGIL = abc.abstractproperty()  # type: str  # type: ignore
 
-    # Deny iteration because it will bite you if you try to create a singleton
-    # set by:
-    #    users = set(user)
-    def __iter__(self):
-        raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
+    localpart = attr.ib(type=str)
+    domain = attr.ib(type=str)
 
     # Because this class is a namedtuple of strings and booleans, it is deeply
     # immutable.
@@ -272,30 +268,35 @@ class DomainSpecificString(
     __repr__ = to_string
 
 
+@attr.s(slots=True, frozen=True, repr=False)
 class UserID(DomainSpecificString):
     """Structure representing a user ID."""
 
     SIGIL = "@"
 
 
+@attr.s(slots=True, frozen=True, repr=False)
 class RoomAlias(DomainSpecificString):
     """Structure representing a room name."""
 
     SIGIL = "#"
 
 
+@attr.s(slots=True, frozen=True, repr=False)
 class RoomID(DomainSpecificString):
     """Structure representing a room id. """
 
     SIGIL = "!"
 
 
+@attr.s(slots=True, frozen=True, repr=False)
 class EventID(DomainSpecificString):
     """Structure representing an event id. """
 
     SIGIL = "$"
 
 
+@attr.s(slots=True, frozen=True, repr=False)
 class GroupID(DomainSpecificString):
     """Structure representing a group ID."""