diff --git a/changelog.d/9875.misc b/changelog.d/9875.misc
new file mode 100644
index 0000000000..9345c0bf45
--- /dev/null
+++ b/changelog.d/9875.misc
@@ -0,0 +1 @@
+Make `DomainSpecificString` an `attrs` class.
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 45514be50f..1c4a43be0a 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -957,6 +957,11 @@ class OidcProvider:
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
+ if attributes.localpart is None:
+ # If no localpart is returned then we will generate one, so
+ # there is no need to search for existing users.
+ return None
+
user_id = UserID(attributes.localpart, self._server_name).to_string()
users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index e5634f9679..488b97b32e 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -61,6 +61,15 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
return
+ # It should be impossible to get here without having first been through
+ # the pick-a-username step, which ensures chosen_localpart gets set.
+ if not session.chosen_localpart:
+ logger.warning("Session has no user name selected")
+ self._sso_handler.render_error(
+ request, "no_user", "No user name has been selected.", code=400
+ )
+ return
+
user_id = UserID(session.chosen_localpart, self._server_name)
user_profile = {
"display_name": session.display_name,
diff --git a/synapse/types.py b/synapse/types.py
index e19f28d543..e52cd7ffd4 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -199,9 +199,8 @@ def get_localpart_from_id(string):
DS = TypeVar("DS", bound="DomainSpecificString")
-class DomainSpecificString(
- namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta
-):
+@attr.s(slots=True, frozen=True, repr=False)
+class DomainSpecificString(metaclass=abc.ABCMeta):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -213,11 +212,8 @@ class DomainSpecificString(
SIGIL = abc.abstractproperty() # type: str # type: ignore
- # Deny iteration because it will bite you if you try to create a singleton
- # set by:
- # users = set(user)
- def __iter__(self):
- raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
+ localpart = attr.ib(type=str)
+ domain = attr.ib(type=str)
# Because this class is a namedtuple of strings and booleans, it is deeply
# immutable.
@@ -272,30 +268,35 @@ class DomainSpecificString(
__repr__ = to_string
+@attr.s(slots=True, frozen=True, repr=False)
class UserID(DomainSpecificString):
"""Structure representing a user ID."""
SIGIL = "@"
+@attr.s(slots=True, frozen=True, repr=False)
class RoomAlias(DomainSpecificString):
"""Structure representing a room name."""
SIGIL = "#"
+@attr.s(slots=True, frozen=True, repr=False)
class RoomID(DomainSpecificString):
"""Structure representing a room id. """
SIGIL = "!"
+@attr.s(slots=True, frozen=True, repr=False)
class EventID(DomainSpecificString):
"""Structure representing an event id. """
SIGIL = "$"
+@attr.s(slots=True, frozen=True, repr=False)
class GroupID(DomainSpecificString):
"""Structure representing a group ID."""
|