summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sync.py64
-rw-r--r--synapse/rest/client/v2_alpha/sync.py5
-rw-r--r--synapse/storage/__init__.py15
-rw-r--r--synapse/storage/group_server.py68
-rw-r--r--synapse/storage/schema/delta/43/group_server.sql9
-rw-r--r--synapse/streams/events.py2
-rw-r--r--synapse/types.py2
7 files changed, 159 insertions, 6 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 91c6c6be3c..c01fcd3d59 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
         return True
 
 
+class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
+    "join",
+    "invite",
+    "leave",
+])):
+    __slots__ = []
+
+    def __nonzero__(self):
+        return self.join or self.invite or self.leave
+
+
 class SyncResult(collections.namedtuple("SyncResult", [
     "next_batch",  # Token for the next sync
     "presence",  # List of presence events for the user.
@@ -119,6 +130,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "device_lists",  # List of user_ids whose devices have chanegd
     "device_one_time_keys_count",  # Dict of algorithm to count for one time keys
                                    # for this device
+    "groups",
 ])):
     __slots__ = []
 
@@ -134,7 +146,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
             self.archived or
             self.account_data or
             self.to_device or
-            self.device_lists
+            self.device_lists or
+            self.groups
         )
 
 
@@ -560,6 +573,8 @@ class SyncHandler(object):
                 user_id, device_id
             )
 
+        yield self._generate_sync_entry_for_groups(sync_result_builder)
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -568,10 +583,56 @@ class SyncHandler(object):
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
             device_lists=device_lists,
+            groups=sync_result_builder.groups,
             device_one_time_keys_count=one_time_key_counts,
             next_batch=sync_result_builder.now_token,
         ))
 
+    @measure_func("_generate_sync_entry_for_groups")
+    @defer.inlineCallbacks
+    def _generate_sync_entry_for_groups(self, sync_result_builder):
+        user_id = sync_result_builder.sync_config.user.to_string()
+        since_token = sync_result_builder.since_token
+        now_token = sync_result_builder.now_token
+
+        if since_token and since_token.groups_key:
+            results = yield self.store.get_groups_changes_for_user(
+                user_id, since_token.groups_key, now_token.groups_key,
+            )
+        else:
+            results = yield self.store.get_all_groups_for_user(
+                user_id, now_token.groups_key,
+            )
+
+        invited = {}
+        joined = {}
+        left = {}
+        for result in results:
+            membership = result["membership"]
+            group_id = result["group_id"]
+            gtype = result["type"]
+            content = result["content"]
+
+            if membership == "join":
+                if gtype == "membership":
+                    content.pop("membership", None)
+                    invited[group_id] = content["content"]
+                else:
+                    joined.setdefault(group_id, {})[gtype] = content
+            elif membership == "invite":
+                if gtype == "membership":
+                    content.pop("membership", None)
+                    invited[group_id] = content["content"]
+            else:
+                if gtype == "membership":
+                    left[group_id] = content["content"]
+
+        sync_result_builder.groups = GroupsSyncResult(
+            join=joined,
+            invite=invited,
+            leave=left,
+        )
+
     @measure_func("_generate_sync_entry_for_device_list")
     @defer.inlineCallbacks
     def _generate_sync_entry_for_device_list(self, sync_result_builder):
@@ -1260,6 +1321,7 @@ class SyncResultBuilder(object):
         self.invited = []
         self.archived = []
         self.device = []
+        self.groups = None
 
 
 class RoomSyncResultBuilder(object):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 6dcc407451..5f208a4c1c 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -199,6 +199,11 @@ class SyncRestServlet(RestServlet):
                 "invite": invited,
                 "leave": archived,
             },
+            "groups": {
+                "join": sync_result.groups.join,
+                "invite": sync_result.groups.invite,
+                "leave": sync_result.groups.leave,
+            },
             "device_one_time_keys_count": sync_result.device_one_time_keys_count,
             "next_batch": sync_result.next_batch.to_string(),
         }
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index fdee9f1ad5..594566eb38 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -136,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "pushers", "id",
             extra_tables=[("deleted_pushers", "stream_id")],
         )
+        self._group_updates_id_gen = StreamIdGenerator(
+            db_conn, "local_group_updates", "stream_id",
+        )
 
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = StreamIdGenerator(
@@ -236,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=curr_state_delta_prefill,
         )
 
+        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+            db_conn, "local_group_updates",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._group_updates_id_gen.get_current_token(),
+            limit=1000,
+        )
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache", min_group_updates_id,
+            prefilled_cache=_group_updates_prefill,
+        )
+
         cur = LoggingTransaction(
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index a2e7aa47d8..45f0a4c599 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -776,7 +776,7 @@ class GroupServerStore(SQLBaseStore):
             remote_attestation (dict): If remote group then store the remote
                 attestation from the group, else None.
         """
-        def _register_user_group_membership_txn(txn):
+        def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
             self._simple_delete_txn(
                 txn,
@@ -798,6 +798,19 @@ class GroupServerStore(SQLBaseStore):
                 },
             )
 
+            self._simple_insert_txn(
+                txn,
+                table="local_group_updates",
+                values={
+                    "stream_id": next_id,
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "type": "membership",
+                    "content": json.dumps({"membership": membership, "content": content}),
+                }
+            )
+            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+
             # TODO: Insert profile to ensure it comes down stream if its a join.
 
             if membership == "join":
@@ -840,10 +853,11 @@ class GroupServerStore(SQLBaseStore):
                     },
                 )
 
-        yield self.runInteraction(
-            "register_user_group_membership",
-            _register_user_group_membership_txn,
-        )
+        with self._group_updates_id_gen.get_next() as next_id:
+            yield self.runInteraction(
+                "register_user_group_membership",
+                _register_user_group_membership_txn, next_id,
+            )
 
     @defer.inlineCallbacks
     def create_group(self, group_id, user_id, name, avatar_url, short_description,
@@ -937,3 +951,47 @@ class GroupServerStore(SQLBaseStore):
             retcol="group_id",
             desc="get_joined_groups",
         )
+
+    def get_all_groups_for_user(self, user_id, now_token):
+        def _get_all_groups_for_user_txn(txn):
+            sql = """
+                SELECT group_id, type, membership, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND membership != 'leave'
+                    AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, now_token,))
+            return self.cursor_to_dict(txn)
+        return self.runInteraction(
+            "get_all_groups_for_user", _get_all_groups_for_user_txn,
+        )
+
+    def get_groups_changes_for_user(self, user_id, from_token, to_token):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_entity_changed(
+            user_id, from_token,
+        )
+        if not has_changed:
+            return []
+
+        def _get_groups_changes_for_user_txn(txn):
+            sql = """
+                SELECT group_id, membership, type, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, from_token, to_token,))
+            return [{
+                "group_id": group_id,
+                "membership": membership,
+                "type": gtype,
+                "content": json.loads(content_json),
+            } for group_id, membership, gtype, content_json in txn]
+        return self.runInteraction(
+            "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+        )
+
+    def get_group_stream_token(self):
+        return self._group_updates_id_gen.get_current_token()
diff --git a/synapse/storage/schema/delta/43/group_server.sql b/synapse/storage/schema/delta/43/group_server.sql
index e1fd47aa7f..92f3339c94 100644
--- a/synapse/storage/schema/delta/43/group_server.sql
+++ b/synapse/storage/schema/delta/43/group_server.sql
@@ -155,3 +155,12 @@ CREATE TABLE local_group_membership (
 
 CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
 CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
+
+
+CREATE TABLE local_group_updates (
+    stream_id BIGINT NOT NULL,
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    content TEXT NOT NULL
+);
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 91a59b0bae..e2be500815 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -45,6 +45,7 @@ class EventSources(object):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
+        groups_key = self.store.get_group_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -65,6 +66,7 @@ class EventSources(object):
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
             device_list_key=device_list_key,
+            groups_key=groups_key,
         )
         defer.returnValue(token)
 
diff --git a/synapse/types.py b/synapse/types.py
index b32c0e360d..37d5fa7f9f 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -171,6 +171,7 @@ class StreamToken(
         "push_rules_key",
         "to_device_key",
         "device_list_key",
+        "groups_key",
     ))
 ):
     _SEPARATOR = "_"
@@ -209,6 +210,7 @@ class StreamToken(
             or (int(other.push_rules_key) < int(self.push_rules_key))
             or (int(other.to_device_key) < int(self.to_device_key))
             or (int(other.device_list_key) < int(self.device_list_key))
+            or (int(other.groups_key) < int(self.groups_key))
         )
 
     def copy_and_advance(self, key, new_value):