summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/rest/directory.py13
-rw-r--r--synapse/rest/room.py6
-rw-r--r--synapse/server.py7
-rw-r--r--tests/test_types.py6
4 files changed, 18 insertions, 14 deletions
diff --git a/synapse/rest/directory.py b/synapse/rest/directory.py
index 31fd26e848..362f76c6ca 100644
--- a/synapse/rest/directory.py
+++ b/synapse/rest/directory.py
@@ -16,7 +16,6 @@
 
 from twisted.internet import defer
 
-from synapse.types import RoomAlias, RoomID
 from base import RestServlet, client_path_pattern
 
 import json
@@ -39,12 +38,11 @@ class ClientDirectoryServer(RestServlet):
         # TODO(erikj): Handle request
         local_only = "local_only" in request.args
 
-        room_alias = urllib.unquote(room_alias)
-        room_alias_obj = RoomAlias.from_string(room_alias, self.hs)
+        room_alias = self.hs.parse_roomalias(urllib.unquote(room_alias))
 
         dir_handler = self.handlers.directory_handler
         res = yield dir_handler.get_association(
-            room_alias_obj,
+            room_alias,
             local_only=local_only
         )
 
@@ -57,10 +55,9 @@ class ClientDirectoryServer(RestServlet):
 
         logger.debug("Got content: %s", content)
 
-        room_alias = urllib.unquote(room_alias)
-        room_alias_obj = RoomAlias.from_string(room_alias, self.hs)
+        room_alias = self.hs.parse_roomalias(urllib.unquote(room_alias))
 
-        logger.debug("Got room name: %s", room_alias_obj.to_string())
+        logger.debug("Got room name: %s", room_alias.to_string())
 
         room_id = content["room_id"]
         servers = content["servers"]
@@ -75,7 +72,7 @@ class ClientDirectoryServer(RestServlet):
 
         try:
             yield dir_handler.create_association(
-                room_alias_obj, room_id, servers
+                room_alias, room_id, servers
             )
         except:
             logger.exception("Failed to create association")
diff --git a/synapse/rest/room.py b/synapse/rest/room.py
index 228bc9623d..1fc0c996b8 100644
--- a/synapse/rest/room.py
+++ b/synapse/rest/room.py
@@ -22,7 +22,6 @@ from synapse.api.events.room import (RoomTopicEvent, MessageEvent,
                                      RoomMemberEvent, FeedbackEvent)
 from synapse.api.constants import Feedback, Membership
 from synapse.api.streams import PaginationConfig
-from synapse.types import RoomAlias
 
 import json
 import logging
@@ -150,10 +149,7 @@ class JoinRoomAliasServlet(RestServlet):
 
         logger.debug("room_alias: %s", room_alias)
 
-        room_alias = RoomAlias.from_string(
-            urllib.unquote(room_alias),
-            self.hs
-        )
+        room_alias = self.hs.parse_roomalias(urllib.unquote(room_alias))
 
         handler = self.handlers.room_member_handler
         ret_dict = yield handler.join_room_alias(user, room_alias)
diff --git a/synapse/server.py b/synapse/server.py
index 0211972d05..96830a88b1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -28,7 +28,7 @@ from synapse.handlers import Handlers
 from synapse.rest import RestServletFactory
 from synapse.state import StateHandler
 from synapse.storage import DataStore
-from synapse.types import UserID
+from synapse.types import UserID, RoomAlias
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
 from synapse.util.lockutils import LockManager
@@ -120,6 +120,11 @@ class BaseHomeServer(object):
         object."""
         return UserID.from_string(s, hs=self)
 
+    def parse_roomalias(self, s):
+        """Parse the string given by 's' as a Room Alias and return a RoomAlias
+        object."""
+        return RoomAlias.from_string(s, hs=self)
+
 # Build magic accessors for every dependency
 for depname in BaseHomeServer.DEPENDENCIES:
     BaseHomeServer._make_dependency_method(depname)
diff --git a/tests/test_types.py b/tests/test_types.py
index 522d52363d..d2ccbcfa55 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -62,3 +62,9 @@ class RoomAliasTestCase(unittest.TestCase):
         room = RoomAlias("channel", "my.domain", True)
 
         self.assertEquals(room.to_string(), "#channel:my.domain")
+
+    def test_via_homeserver(self):
+        room = mock_homeserver.parse_roomalias("#elsewhere:my.domain")
+
+        self.assertEquals("elsewhere", room.localpart)
+        self.assertEquals("my.domain", room.domain)