summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@jki.re>2017-07-21 11:05:39 +0100
committerGitHub <noreply@github.com>2017-07-21 11:05:39 +0100
commit96917d555240f52c98d698f95e2ff1e65c8b0b5d (patch)
treedf70bdc99fc2db0a229a42a55b1de2efda99e883
parentMerge pull request #2377 from matrix-org/erikj/group_profile_update (diff)
parentAdd notifier (diff)
downloadsynapse-96917d555240f52c98d698f95e2ff1e65c8b0b5d.tar.xz
Merge pull request #2378 from matrix-org/erikj/group_sync_support
Add groups to sync stream
Diffstat (limited to '')
-rw-r--r--synapse/app/synchrotron.py6
-rw-r--r--synapse/handlers/groups_local.py21
-rw-r--r--synapse/handlers/sync.py64
-rw-r--r--synapse/replication/slave/storage/groups.py54
-rw-r--r--synapse/replication/tcp/streams.py20
-rw-r--r--synapse/rest/client/v2_alpha/sync.py5
-rw-r--r--synapse/storage/__init__.py15
-rw-r--r--synapse/storage/group_server.py91
-rw-r--r--synapse/storage/schema/delta/43/group_server.sql9
-rw-r--r--synapse/streams/events.py4
-rw-r--r--synapse/types.py2
-rw-r--r--tests/rest/client/v1/test_rooms.py4
12 files changed, 283 insertions, 12 deletions
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 4bdd99a966..d06a05acd9 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -41,6 +41,7 @@ from synapse.replication.slave.storage.presence import SlavedPresenceStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
@@ -75,6 +76,7 @@ class SynchrotronSlavedStore(
     SlavedRegistrationStore,
     SlavedFilteringStore,
     SlavedPresenceStore,
+    SlavedGroupServerStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedClientIpStore,
@@ -409,6 +411,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
             )
         elif stream_name == "presence":
             yield self.presence_handler.process_replication_rows(token, rows)
+        elif stream_name == "receipts":
+            self.notifier.on_new_event(
+                "groups_key", token, users=[row.user_id for row in rows],
+            )
 
 
 def start(config_options):
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index b2c920da38..d0ed988224 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -63,6 +63,7 @@ class GroupsLocalHandler(object):
         self.is_mine_id = hs.is_mine_id
         self.signing_key = hs.config.signing_key[0]
         self.server_name = hs.hostname
+        self.notifier = hs.get_notifier()
         self.attestations = hs.get_groups_attestation_signing()
 
         # Ensure attestations get renewed
@@ -212,13 +213,16 @@ class GroupsLocalHandler(object):
                 user_id=user_id,
             )
 
-        yield self.store.register_user_group_membership(
+        token = yield self.store.register_user_group_membership(
             group_id, user_id,
             membership="join",
             is_admin=False,
             local_attestation=local_attestation,
             remote_attestation=remote_attestation,
         )
+        self.notifier.on_new_event(
+            "groups_key", token, users=[user_id],
+        )
 
         defer.returnValue({})
 
@@ -258,11 +262,14 @@ class GroupsLocalHandler(object):
             if "avatar_url" in content["profile"]:
                 local_profile["avatar_url"] = content["profile"]["avatar_url"]
 
-        yield self.store.register_user_group_membership(
+        token = yield self.store.register_user_group_membership(
             group_id, user_id,
             membership="invite",
             content={"profile": local_profile, "inviter": content["inviter"]},
         )
+        self.notifier.on_new_event(
+            "groups_key", token, users=[user_id],
+        )
 
         defer.returnValue({"state": "invite"})
 
@@ -271,10 +278,13 @@ class GroupsLocalHandler(object):
         """Remove a user from a group
         """
         if user_id == requester_user_id:
-            yield self.store.register_user_group_membership(
+            token = yield self.store.register_user_group_membership(
                 group_id, user_id,
                 membership="leave",
             )
+            self.notifier.on_new_event(
+                "groups_key", token, users=[user_id],
+            )
 
             # TODO: Should probably remember that we tried to leave so that we can
             # retry if the group server is currently down.
@@ -297,10 +307,13 @@ class GroupsLocalHandler(object):
         """One of our users was removed/kicked from a group
         """
         # TODO: Check if user in group
-        yield self.store.register_user_group_membership(
+        token = yield self.store.register_user_group_membership(
             group_id, user_id,
             membership="leave",
         )
+        self.notifier.on_new_event(
+            "groups_key", token, users=[user_id],
+        )
 
     @defer.inlineCallbacks
     def get_joined_groups(self, user_id):
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 91c6c6be3c..600d0589fd 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
         return True
 
 
+class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
+    "join",
+    "invite",
+    "leave",
+])):
+    __slots__ = []
+
+    def __nonzero__(self):
+        return bool(self.join or self.invite or self.leave)
+
+
 class SyncResult(collections.namedtuple("SyncResult", [
     "next_batch",  # Token for the next sync
     "presence",  # List of presence events for the user.
@@ -119,6 +130,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "device_lists",  # List of user_ids whose devices have chanegd
     "device_one_time_keys_count",  # Dict of algorithm to count for one time keys
                                    # for this device
+    "groups",
 ])):
     __slots__ = []
 
@@ -134,7 +146,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
             self.archived or
             self.account_data or
             self.to_device or
-            self.device_lists
+            self.device_lists or
+            self.groups
         )
 
 
@@ -560,6 +573,8 @@ class SyncHandler(object):
                 user_id, device_id
             )
 
+        yield self._generate_sync_entry_for_groups(sync_result_builder)
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -568,10 +583,56 @@ class SyncHandler(object):
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
             device_lists=device_lists,
+            groups=sync_result_builder.groups,
             device_one_time_keys_count=one_time_key_counts,
             next_batch=sync_result_builder.now_token,
         ))
 
+    @measure_func("_generate_sync_entry_for_groups")
+    @defer.inlineCallbacks
+    def _generate_sync_entry_for_groups(self, sync_result_builder):
+        user_id = sync_result_builder.sync_config.user.to_string()
+        since_token = sync_result_builder.since_token
+        now_token = sync_result_builder.now_token
+
+        if since_token and since_token.groups_key:
+            results = yield self.store.get_groups_changes_for_user(
+                user_id, since_token.groups_key, now_token.groups_key,
+            )
+        else:
+            results = yield self.store.get_all_groups_for_user(
+                user_id, now_token.groups_key,
+            )
+
+        invited = {}
+        joined = {}
+        left = {}
+        for result in results:
+            membership = result["membership"]
+            group_id = result["group_id"]
+            gtype = result["type"]
+            content = result["content"]
+
+            if membership == "join":
+                if gtype == "membership":
+                    content.pop("membership", None)
+                    invited[group_id] = content["content"]
+                else:
+                    joined.setdefault(group_id, {})[gtype] = content
+            elif membership == "invite":
+                if gtype == "membership":
+                    content.pop("membership", None)
+                    invited[group_id] = content["content"]
+            else:
+                if gtype == "membership":
+                    left[group_id] = content["content"]
+
+        sync_result_builder.groups = GroupsSyncResult(
+            join=joined,
+            invite=invited,
+            leave=left,
+        )
+
     @measure_func("_generate_sync_entry_for_device_list")
     @defer.inlineCallbacks
     def _generate_sync_entry_for_device_list(self, sync_result_builder):
@@ -1260,6 +1321,7 @@ class SyncResultBuilder(object):
         self.invited = []
         self.archived = []
         self.device = []
+        self.groups = None
 
 
 class RoomSyncResultBuilder(object):
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
new file mode 100644
index 0000000000..0bc4bce5b0
--- /dev/null
+++ b/synapse/replication/slave/storage/groups.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedGroupServerStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+
+        self.hs = hs
+
+        self._group_updates_id_gen = SlavedIdTracker(
+            db_conn, "local_group_updates", "stream_id",
+        )
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
+        )
+
+    get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
+    get_group_stream_token = DataStore.get_group_stream_token.__func__
+    get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
+
+    def stream_positions(self):
+        result = super(SlavedGroupServerStore, self).stream_positions()
+        result["groups"] = self._group_updates_id_gen.get_current_token()
+        return result
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "groups":
+            self._group_updates_id_gen.advance(token)
+            for row in rows:
+                self._group_updates_stream_cache.entity_has_changed(
+                    row.user_id, token
+                )
+
+        return super(SlavedGroupServerStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index fbafe12cc2..4c60bf79f9 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
     "state_key",  # str
     "event_id",  # str, optional
 ))
+GroupsStreamRow = namedtuple("GroupsStreamRow", (
+    "group_id",  # str
+    "user_id",  # str
+    "type",  # str
+    "content",  # dict
+))
 
 
 class Stream(object):
@@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
         super(CurrentStateDeltaStream, self).__init__(hs)
 
 
+class GroupServerStream(Stream):
+    NAME = "groups"
+    ROW_TYPE = GroupsStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_group_stream_token
+        self.update_function = store.get_all_groups_changes
+
+        super(GroupServerStream, self).__init__(hs)
+
+
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
@@ -482,5 +501,6 @@ STREAMS_MAP = {
         TagAccountDataStream,
         AccountDataStream,
         CurrentStateDeltaStream,
+        GroupServerStream,
     )
 }
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 6dcc407451..5f208a4c1c 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -199,6 +199,11 @@ class SyncRestServlet(RestServlet):
                 "invite": invited,
                 "leave": archived,
             },
+            "groups": {
+                "join": sync_result.groups.join,
+                "invite": sync_result.groups.invite,
+                "leave": sync_result.groups.leave,
+            },
             "device_one_time_keys_count": sync_result.device_one_time_keys_count,
             "next_batch": sync_result.next_batch.to_string(),
         }
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index fdee9f1ad5..594566eb38 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -136,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "pushers", "id",
             extra_tables=[("deleted_pushers", "stream_id")],
         )
+        self._group_updates_id_gen = StreamIdGenerator(
+            db_conn, "local_group_updates", "stream_id",
+        )
 
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = StreamIdGenerator(
@@ -236,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=curr_state_delta_prefill,
         )
 
+        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+            db_conn, "local_group_updates",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._group_updates_id_gen.get_current_token(),
+            limit=1000,
+        )
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache", min_group_updates_id,
+            prefilled_cache=_group_updates_prefill,
+        )
+
         cur = LoggingTransaction(
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index ce0f513e30..d42e215b26 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -776,7 +776,7 @@ class GroupServerStore(SQLBaseStore):
             remote_attestation (dict): If remote group then store the remote
                 attestation from the group, else None.
         """
-        def _register_user_group_membership_txn(txn):
+        def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
             self._simple_delete_txn(
                 txn,
@@ -798,6 +798,19 @@ class GroupServerStore(SQLBaseStore):
                 },
             )
 
+            self._simple_insert_txn(
+                txn,
+                table="local_group_updates",
+                values={
+                    "stream_id": next_id,
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "type": "membership",
+                    "content": json.dumps({"membership": membership, "content": content}),
+                }
+            )
+            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+
             # TODO: Insert profile to ensure it comes down stream if its a join.
 
             if membership == "join":
@@ -840,10 +853,13 @@ class GroupServerStore(SQLBaseStore):
                     },
                 )
 
-        yield self.runInteraction(
-            "register_user_group_membership",
-            _register_user_group_membership_txn,
-        )
+            return next_id
+
+        with self._group_updates_id_gen.get_next() as next_id:
+            yield self.runInteraction(
+                "register_user_group_membership",
+                _register_user_group_membership_txn, next_id,
+            )
 
     @defer.inlineCallbacks
     def create_group(self, group_id, user_id, name, avatar_url, short_description,
@@ -948,3 +964,68 @@ class GroupServerStore(SQLBaseStore):
             retcol="group_id",
             desc="get_joined_groups",
         )
+
+    def get_all_groups_for_user(self, user_id, now_token):
+        def _get_all_groups_for_user_txn(txn):
+            sql = """
+                SELECT group_id, type, membership, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND membership != 'leave'
+                    AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, now_token,))
+            return self.cursor_to_dict(txn)
+        return self.runInteraction(
+            "get_all_groups_for_user", _get_all_groups_for_user_txn,
+        )
+
+    def get_groups_changes_for_user(self, user_id, from_token, to_token):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_entity_changed(
+            user_id, from_token,
+        )
+        if not has_changed:
+            return []
+
+        def _get_groups_changes_for_user_txn(txn):
+            sql = """
+                SELECT group_id, membership, type, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, from_token, to_token,))
+            return [{
+                "group_id": group_id,
+                "membership": membership,
+                "type": gtype,
+                "content": json.loads(content_json),
+            } for group_id, membership, gtype, content_json in txn]
+        return self.runInteraction(
+            "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+        )
+
+    def get_all_groups_changes(self, from_token, to_token, limit):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+            from_token,
+        )
+        if not has_changed:
+            return []
+
+        def _get_all_groups_changes_txn(txn):
+            sql = """
+                SELECT stream_id, group_id, user_id, type, content
+                FROM local_group_updates
+                WHERE ? < stream_id AND stream_id <= ?
+                LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, limit,))
+            return txn.fetchall()
+        return self.runInteraction(
+            "get_all_groups_changes", _get_all_groups_changes_txn,
+        )
+
+    def get_group_stream_token(self):
+        return self._group_updates_id_gen.get_current_token()
diff --git a/synapse/storage/schema/delta/43/group_server.sql b/synapse/storage/schema/delta/43/group_server.sql
index e1fd47aa7f..92f3339c94 100644
--- a/synapse/storage/schema/delta/43/group_server.sql
+++ b/synapse/storage/schema/delta/43/group_server.sql
@@ -155,3 +155,12 @@ CREATE TABLE local_group_membership (
 
 CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
 CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
+
+
+CREATE TABLE local_group_updates (
+    stream_id BIGINT NOT NULL,
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    content TEXT NOT NULL
+);
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 91a59b0bae..f03ad99118 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -45,6 +45,7 @@ class EventSources(object):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
+        groups_key = self.store.get_group_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -65,6 +66,7 @@ class EventSources(object):
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
             device_list_key=device_list_key,
+            groups_key=groups_key,
         )
         defer.returnValue(token)
 
@@ -73,6 +75,7 @@ class EventSources(object):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
+        groups_key = self.store.get_group_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -93,5 +96,6 @@ class EventSources(object):
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
             device_list_key=device_list_key,
+            groups_key=groups_key,
         )
         defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index b32c0e360d..37d5fa7f9f 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -171,6 +171,7 @@ class StreamToken(
         "push_rules_key",
         "to_device_key",
         "device_list_key",
+        "groups_key",
     ))
 ):
     _SEPARATOR = "_"
@@ -209,6 +210,7 @@ class StreamToken(
             or (int(other.push_rules_key) < int(self.push_rules_key))
             or (int(other.to_device_key) < int(self.to_device_key))
             or (int(other.device_list_key) < int(self.device_list_key))
+            or (int(other.groups_key) < int(self.groups_key))
         )
 
     def copy_and_advance(self, key, new_value):
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d746ea8568..de376fb514 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_topo_token_is_accepted(self):
-        token = "t1-0_0_0_0_0_0_0_0"
+        token = "t1-0_0_0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_stream_token_is_accepted_for_fwd_pagianation(self):
-        token = "s0_0_0_0_0_0_0_0"
+        token = "s0_0_0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))