summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/groups_local.py43
1 files changed, 25 insertions, 18 deletions
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 3df255b05a..e0f53120be 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -32,15 +32,17 @@ logger = logging.getLogger(__name__)
 # TODO: Add group memebership  /sync
 
 
-def _create_rerouter(name):
+def _create_rerouter(func_name):
+    """Returns a function that looks at the group id and calls the function
+    on federation or the local group server if the group is local
+    """
     def f(self, group_id, *args, **kwargs):
         if self.is_mine_id(group_id):
-            return getattr(self.groups_server_handler, name)(
+            return getattr(self.groups_server_handler, func_name)(
                 group_id, *args, **kwargs
             )
 
-        repl_layer = self.hs.get_replication_layer()
-        return getattr(repl_layer, name)(group_id, *args, **kwargs)
+        return getattr(self.transport_client, func_name)(group_id, *args, **kwargs)
     return f
 
 
@@ -50,6 +52,7 @@ class GroupsLocalHandler(object):
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
         self.groups_server_handler = hs.get_groups_server_handler()
+        self.transport_client = hs.get_federation_transport_client()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.keyring = hs.get_keyring()
@@ -82,15 +85,19 @@ class GroupsLocalHandler(object):
 
     @defer.inlineCallbacks
     def get_group_summary(self, group_id, requester_user_id):
+        """Get the group summary for a group.
+
+        If the group is remote we check that the users have valid attestations.
+        """
         if self.is_mine_id(group_id):
             res = yield self.groups_server_handler.get_group_summary(
                 group_id, requester_user_id
             )
             defer.returnValue(res)
 
-        repl_layer = self.hs.get_replication_layer()
-        res = yield repl_layer.get_group_summary(group_id, requester_user_id)
+        res = yield self.transport_client.get_group_summary(group_id, requester_user_id)
 
+        # Loop through the users and validate the attestations.
         chunk = res["users_section"]["users"]
         valid_users = []
         for entry in chunk:
@@ -121,8 +128,7 @@ class GroupsLocalHandler(object):
                 group_id, user_id, content
             )
 
-        repl_layer = self.hs.get_replication_layer()
-        return repl_layer.create_group(group_id, user_id, content)  # TODO
+        return self.transport_client.create_group(group_id, user_id, content)  # TODO
 
     def add_room(self, group_id, user_id, room_id, content):
         if self.is_mine_id(group_id):
@@ -130,8 +136,9 @@ class GroupsLocalHandler(object):
                 group_id, user_id, room_id, content
             )
 
-        repl_layer = self.hs.get_replication_layer()
-        return repl_layer.add_room_to_group(group_id, user_id, room_id, content)  # TODO
+        return self.transport_client.add_room_to_group(
+            group_id, user_id, room_id, content,
+        )  # TODO
 
     @defer.inlineCallbacks
     def get_users_in_group(self, group_id, requester_user_id):
@@ -141,8 +148,9 @@ class GroupsLocalHandler(object):
             )
             defer.returnValue(res)
 
-        repl_layer = self.hs.get_replication_layer()
-        res = yield repl_layer.get_users_in_group(group_id, requester_user_id)  # TODO
+        res = yield self.transport_client.get_users_in_group(
+            group_id, requester_user_id,
+        )  # TODO
 
         chunk = res["chunk"]
         valid_entries = []
@@ -179,8 +187,9 @@ class GroupsLocalHandler(object):
             local_attestation = self.attestations.create_attestation(group_id, user_id)
             content["attestation"] = local_attestation
 
-            repl_layer = self.hs.get_replication_layer()
-            res = yield repl_layer.accept_group_invite(group_id, user_id, content)
+            res = yield self.transport_client.accept_group_invite(
+                group_id, user_id, content,
+            )
 
             remote_attestation = res["attestation"]
 
@@ -211,8 +220,7 @@ class GroupsLocalHandler(object):
                 group_id, user_id, requester_user_id, content,
             )
         else:
-            repl_layer = self.hs.get_replication_layer()
-            res = yield repl_layer.invite_to_group(
+            res = yield self.transport_client.invite_to_group(
                 group_id, user_id, content,
             )
 
@@ -257,8 +265,7 @@ class GroupsLocalHandler(object):
             )
         else:
             content["requester_user_id"] = requester_user_id
-            repl_layer = self.hs.get_replication_layer()
-            res = yield repl_layer.remove_user_from_group(
+            res = yield self.transport_client.remove_user_from_group(
                 group_id, user_id, content
             )  # TODO