summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/errors.py5
-rw-r--r--synapse/handlers/profile.py11
-rw-r--r--synapse/handlers/room.py44
-rw-r--r--synapse/rest/client/v1/room.py68
-rw-r--r--synapse/types.py14
5 files changed, 73 insertions, 69 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index b106fbed6d..0c7858f78d 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -84,6 +84,11 @@ class RegistrationError(SynapseError):
     pass
 
 
+class BadIdentifierError(SynapseError):
+    """An error indicating an identifier couldn't be parsed."""
+    pass
+
+
 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
     def __init__(self, *args, **kwargs):
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 629e6e3594..32af622733 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -169,8 +169,15 @@ class ProfileHandler(BaseHandler):
             consumeErrors=True
         ).addErrback(unwrapFirstError)
 
-        state["displayname"] = displayname
-        state["avatar_url"] = avatar_url
+        if displayname is None:
+            del state["displayname"]
+        else:
+            state["displayname"] = displayname
+
+        if avatar_url is None:
+            del state["avatar_url"]
+        else:
+            state["avatar_url"] = avatar_url
 
         defer.returnValue(None)
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b2de2cd0c0..2950ed14e4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -527,7 +527,17 @@ class RoomMemberHandler(BaseHandler):
         defer.returnValue({"room_id": room_id})
 
     @defer.inlineCallbacks
-    def join_room_alias(self, joinee, room_alias, content={}):
+    def lookup_room_alias(self, room_alias):
+        """
+        Gets the room ID for an alias.
+
+        Args:
+            room_alias (str): The room alias to look up.
+        Returns:
+            A tuple of the room ID (str) and the hosts hosting the room ([str])
+        Raises:
+            SynapseError if the room couldn't be looked up.
+        """
         directory_handler = self.hs.get_handlers().directory_handler
         mapping = yield directory_handler.get_association(room_alias)
 
@@ -539,24 +549,40 @@ class RoomMemberHandler(BaseHandler):
         if not hosts:
             raise SynapseError(404, "No known servers")
 
-        # If event doesn't include a display name, add one.
-        yield collect_presencelike_data(self.distributor, joinee, content)
+        defer.returnValue((room_id, hosts))
+
+    @defer.inlineCallbacks
+    def do_join(self, requester, room_id, hosts=None):
+        """
+        Joins requester to room_id.
+
+        Args:
+            requester (Requester): The user joining the room.
+            room_id (str): The room ID (not alias) being joined.
+            hosts ([str]): A list of hosts which are hopefully in the room.
+        Raises:
+            SynapseError if the room couldn't be joined.
+        """
+        hosts = hosts or []
+
+        content = {"membership": Membership.JOIN}
+        if requester.is_guest:
+            content["kind"] = "guest"
+
+        yield collect_presencelike_data(self.distributor, requester.user, content)
 
-        content.update({"membership": Membership.JOIN})
         builder = self.event_builder_factory.new({
             "type": EventTypes.Member,
-            "state_key": joinee.to_string(),
+            "state_key": requester.user.to_string(),
             "room_id": room_id,
-            "sender": joinee.to_string(),
-            "membership": Membership.JOIN,
+            "sender": requester.user.to_string(),
+            "membership": Membership.JOIN,  # For backwards compatibility
             "content": content,
         })
         event, context = yield self._create_new_client_event(builder)
 
         yield self._do_join(event, context, room_hosts=hosts)
 
-        defer.returnValue({"room_id": room_id})
-
     @defer.inlineCallbacks
     def _do_join(self, event, context, room_hosts=None):
         room_id = event.room_id
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 81bfe377bd..1dd33b0a56 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -216,11 +216,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 
 # TODO: Needs unit testing for room ID + alias joins
 class JoinRoomAliasServlet(ClientV1RestServlet):
-
-    def register(self, http_server):
-        # /join/$room_identifier[/$txn_id]
-        PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
-        register_txn_path(self, PATTERNS, http_server)
+    PATTERNS = client_path_patterns("/join/(?P<room_identifier>[^/]*)$")
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_identifier, txn_id=None):
@@ -229,60 +225,22 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
             allow_guest=True,
         )
 
-        # the identifier could be a room alias or a room id. Try one then the
-        # other if it fails to parse, without swallowing other valid
-        # SynapseErrors.
+        handler = self.handlers.room_member_handler
 
-        identifier = None
-        is_room_alias = False
-        try:
-            identifier = RoomAlias.from_string(room_identifier)
-            is_room_alias = True
-        except SynapseError:
-            identifier = RoomID.from_string(room_identifier)
+        room_id = None
+        hosts = []
+        if RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, hosts = yield handler.lookup_room_alias(room_alias)
+        else:
+            room_id = RoomID.from_string(room_identifier).to_string()
 
         # TODO: Support for specifying the home server to join with?
 
-        if is_room_alias:
-            handler = self.handlers.room_member_handler
-            ret_dict = yield handler.join_room_alias(
-                requester.user,
-                identifier,
-            )
-            defer.returnValue((200, ret_dict))
-        else:  # room id
-            msg_handler = self.handlers.message_handler
-            content = {"membership": Membership.JOIN}
-            if requester.is_guest:
-                content["kind"] = "guest"
-            yield msg_handler.create_and_send_event(
-                {
-                    "type": EventTypes.Member,
-                    "content": content,
-                    "room_id": identifier.to_string(),
-                    "sender": requester.user.to_string(),
-                    "state_key": requester.user.to_string(),
-                },
-                token_id=requester.access_token_id,
-                txn_id=txn_id,
-                is_guest=requester.is_guest,
-            )
-
-            defer.returnValue((200, {"room_id": identifier.to_string()}))
-
-    @defer.inlineCallbacks
-    def on_PUT(self, request, room_identifier, txn_id):
-        try:
-            defer.returnValue(
-                self.txns.get_client_transaction(request, txn_id)
-            )
-        except KeyError:
-            pass
-
-        response = yield self.on_POST(request, room_identifier, txn_id)
-
-        self.txns.store_client_transaction(request, txn_id, response)
-        defer.returnValue(response)
+        yield handler.do_join(
+            requester, room_id, hosts=hosts
+        )
+        defer.returnValue((200, {"room_id": room_id}))
 
 
 # TODO: Needs unit testing
diff --git a/synapse/types.py b/synapse/types.py
index 2095837ba6..0be8384e18 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, BadIdentifierError
 
 from collections import namedtuple
 
@@ -51,13 +51,13 @@ class DomainSpecificString(
     def from_string(cls, s):
         """Parse the string given by 's' into a structure object."""
         if len(s) < 1 or s[0] != cls.SIGIL:
-            raise SynapseError(400, "Expected %s string to start with '%s'" % (
+            raise BadIdentifierError(400, "Expected %s string to start with '%s'" % (
                 cls.__name__, cls.SIGIL,
             ))
 
         parts = s[1:].split(':', 1)
         if len(parts) != 2:
-            raise SynapseError(
+            raise BadIdentifierError(
                 400, "Expected %s of the form '%slocalname:domain'" % (
                     cls.__name__, cls.SIGIL,
                 )
@@ -69,6 +69,14 @@ class DomainSpecificString(
         # names on one HS
         return cls(localpart=parts[0], domain=domain)
 
+    @classmethod
+    def is_valid(cls, s):
+        try:
+            cls.from_string(s)
+            return True
+        except:
+            return False
+
     def to_string(self):
         """Return a string encoding the fields of the structure object."""
         return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)