summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py15
-rw-r--r--synapse/storage/group_server.py152
-rw-r--r--synapse/storage/schema/delta/43/group_server.sql28
3 files changed, 195 insertions, 0 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index fdee9f1ad5..594566eb38 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -136,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "pushers", "id",
             extra_tables=[("deleted_pushers", "stream_id")],
         )
+        self._group_updates_id_gen = StreamIdGenerator(
+            db_conn, "local_group_updates", "stream_id",
+        )
 
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = StreamIdGenerator(
@@ -236,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=curr_state_delta_prefill,
         )
 
+        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+            db_conn, "local_group_updates",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._group_updates_id_gen.get_current_token(),
+            limit=1000,
+        )
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache", min_group_updates_id,
+            prefilled_cache=_group_updates_prefill,
+        )
+
         cur = LoggingTransaction(
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index e8a799d8c7..036549d437 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -757,6 +757,103 @@ class GroupServerStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
+    def register_user_group_membership(self, group_id, user_id, membership,
+                                       is_admin=False, content={},
+                                       local_attestation=None,
+                                       remote_attestation=None,
+                                       ):
+        def _register_user_group_membership_txn(txn, next_id):
+            # TODO: Upsert?
+            self._simple_delete_txn(
+                txn,
+                table="local_group_membership",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_insert_txn(
+                txn,
+                table="local_group_membership",
+                values={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "is_admin": is_admin,
+                    "membership": membership,
+                    "content": json.dumps(content),
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="local_group_updates",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "type": "membership",
+                },
+            )
+            self._simple_insert_txn(
+                txn,
+                table="local_group_updates",
+                values={
+                    "stream_id": next_id,
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "type": "membership",
+                    "content": json.dumps({"membership": membership, "content": content}),
+                }
+            )
+            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+
+            # TODO: Insert profile to ensuer it comes down stream if its a join.
+
+            if membership == "join":
+                if local_attestation:
+                    self._simple_insert_txn(
+                        txn,
+                        table="group_attestations_renewals",
+                        values={
+                            "group_id": group_id,
+                            "user_id": user_id,
+                            "valid_until_ms": local_attestation["valid_until_ms"],
+                        }
+                    )
+                if remote_attestation:
+                    self._simple_insert_txn(
+                        txn,
+                        table="group_attestations_remote",
+                        values={
+                            "group_id": group_id,
+                            "user_id": user_id,
+                            "valid_until_ms": remote_attestation["valid_until_ms"],
+                            "attestation": json.dumps(remote_attestation),
+                        }
+                    )
+            else:
+                self._simple_delete_txn(
+                    txn,
+                    table="group_attestations_renewals",
+                    keyvalues={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                    },
+                )
+                self._simple_delete_txn(
+                    txn,
+                    table="group_attestations_remote",
+                    keyvalues={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                    },
+                )
+
+        with self._group_updates_id_gen.get_next() as next_id:
+            yield self.runInteraction(
+                "register_user_group_membership",
+                _register_user_group_membership_txn, next_id,
+            )
+
+    @defer.inlineCallbacks
     def create_group(self, group_id, user_id, name, avatar_url, short_description,
                      long_description,):
         yield self._simple_insert(
@@ -771,6 +868,61 @@ class GroupServerStore(SQLBaseStore):
             desc="create_group",
         )
 
+    def get_joined_groups(self, user_id):
+        return self._simple_select_onecol(
+            table="local_group_membership",
+            keyvalues={
+                "user_id": user_id,
+                "membership": "join",
+            },
+            retcol="group_id",
+            desc="get_joined_groups",
+        )
+
+    def get_all_groups_for_user(self, user_id, now_token):
+        def _get_all_groups_for_user_txn(txn):
+            sql = """
+                SELECT group_id, type, membership, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND membership != 'leave'
+                    AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, now_token,))
+            return self.cursor_to_dict(txn)
+        return self.runInteraction(
+            "get_all_groups_for_user", _get_all_groups_for_user_txn,
+        )
+
+    def get_groups_changes_for_user(self, user_id, from_token, to_token):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_entity_changed(
+            user_id, from_token,
+        )
+        if not has_changed:
+            return []
+
+        def _get_groups_changes_for_user_txn(txn):
+            sql = """
+                SELECT group_id, membership, type, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, from_token, to_token,))
+            return [{
+                "group_id": group_id,
+                "membership": membership,
+                "type": gtype,
+                "content": json.loads(content_json),
+            } for group_id, membership, gtype, content_json in txn]
+        return self.runInteraction(
+            "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+        )
+
+    def get_group_stream_token(self):
+        return self._group_updates_id_gen.get_current_token()
+
     def get_attestations_need_renewals(self, valid_until_ms):
         """Get all attestations that need to be renewed until givent time
         """
diff --git a/synapse/storage/schema/delta/43/group_server.sql b/synapse/storage/schema/delta/43/group_server.sql
index 472aab0a78..e32db8b313 100644
--- a/synapse/storage/schema/delta/43/group_server.sql
+++ b/synapse/storage/schema/delta/43/group_server.sql
@@ -142,3 +142,31 @@ CREATE TABLE group_attestations_remote (
 CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
 CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
 CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
+
+
+CREATE TABLE local_group_membership (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    is_admin BOOLEAN NOT NULL,
+    membership TEXT NOT NULL,
+    content TEXT NOT NULL
+);
+
+CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
+CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
+
+
+CREATE TABLE local_group_updates (
+    stream_id BIGINT NOT NULL,
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    content TEXT NOT NULL
+);
+
+
+CREATE TABLE local_group_profiles (
+    group_id TEXT NOT NULL,
+    name TEXT,
+    avatar_url TEXT
+);