diff --git a/changelog.d/5589.feature b/changelog.d/5589.feature
new file mode 100644
index 0000000000..a87e669dd4
--- /dev/null
+++ b/changelog.d/5589.feature
@@ -0,0 +1 @@
+Add ability to pull all locally stored events out of synapse that a particular user can see.
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 941ebfa107..e8a651e231 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -17,6 +17,10 @@ import logging
from twisted.internet import defer
+from synapse.api.constants import Membership
+from synapse.types import RoomStreamToken
+from synapse.visibility import filter_events_for_client
+
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -89,3 +93,182 @@ class AdminHandler(BaseHandler):
ret = yield self.store.search_users(term)
defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def export_user_data(self, user_id, writer):
+        """Write all data we have on the user to the given writer.
+
+        Args:
+            user_id (str): The user whose data is being exported.
+            writer (ExfiltrationWriter): Receives the exported data.
+
+        Returns:
+            defer.Deferred: Resolves when all data for a user has been written.
+            The returned value is that returned by `writer.finished()`.
+        """
+        # Get all rooms the user is in or has been in
+        rooms = yield self.store.get_rooms_for_user_where_membership_is(
+            user_id,
+            membership_list=(
+                Membership.JOIN,
+                Membership.LEAVE,
+                Membership.BAN,
+                Membership.INVITE,
+            ),
+        )
+
+        # We only try and fetch events for rooms the user has been in. If
+        # they've been e.g. invited to a room without joining then we handle
+        # those separately.
+        rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id)
+
+        for index, room in enumerate(rooms):
+            room_id = room.room_id
+
+            logger.info(
+                "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
+            )
+
+            forgotten = yield self.store.did_forget(user_id, room_id)
+            if forgotten:
+                logger.info("[%s] User forgot room %s, ignoring", user_id, room_id)
+                continue
+
+            if room_id not in rooms_user_has_been_in:
+                # If we haven't been in the rooms then the filtering code below
+                # won't return anything, so we need to handle these cases
+                # explicitly.
+
+                if room.membership == Membership.INVITE:
+                    event_id = room.event_id
+                    invite = yield self.store.get_event(event_id, allow_none=True)
+                    # get_event is called with allow_none=True, so this may be None.
+                    if invite:
+                        invited_state = invite.unsigned["invite_room_state"]
+                        writer.write_invite(room_id, invite, invited_state)
+
+                continue
+
+            # We only want to bother fetching events up to the last time they
+            # were joined. We estimate that point by looking at the
+            # stream_ordering of the last membership if it wasn't a join.
+            if room.membership == Membership.JOIN:
+                stream_ordering = yield self.store.get_room_max_stream_ordering()
+            else:
+                stream_ordering = room.stream_ordering
+
+            from_key = str(RoomStreamToken(0, 0))
+            to_key = str(RoomStreamToken(None, stream_ordering))
+
+            written_events = set()  # Events that we've processed in this room
+
+            # We need to track gaps in the events stream so that we can then
+            # write out the state at those events. We do this by keeping track
+            # of events whose prev events we haven't seen.
+
+            # Map from event ID to prev events that haven't been processed,
+            # dict[str, set[str]].
+            event_to_unseen_prevs = {}
+
+            # The reverse mapping to above, i.e. map from unseen event to events
+            # that have the unseen event in their prev_events, i.e. the unseen
+            # events "children". dict[str, set[str]]
+            unseen_to_child_events = {}
+
+            # We fetch events in the room the user could see by fetching *all*
+            # events that we have and then filtering, this isn't the most
+            # efficient method perhaps but it does guarantee we get everything.
+            while True:
+                events, _ = yield self.store.paginate_room_events(
+                    room_id, from_key, to_key, limit=100, direction="f"
+                )
+                if not events:
+                    break
+
+                # The next page starts from the last fetched event's `after` token.
+                from_key = events[-1].internal_metadata.after
+
+                # This is the filtering step described above the loop.
+                events = yield filter_events_for_client(self.store, user_id, events)
+
+                writer.write_events(room_id, events)
+
+                # Update the extremity tracking dicts
+                for event in events:
+                    # Check if we have any prev events that haven't been
+                    # processed yet, and add those to the appropriate dicts.
+                    unseen_events = set(event.prev_event_ids()) - written_events
+                    if unseen_events:
+                        event_to_unseen_prevs[event.event_id] = unseen_events
+                        for unseen in unseen_events:
+                            unseen_to_child_events.setdefault(unseen, set()).add(
+                                event.event_id
+                            )
+
+                    # Now check if this event is an unseen prev event, if so
+                    # then we remove this event from the appropriate dicts.
+                    for child_id in unseen_to_child_events.pop(event.event_id, []):
+                        event_to_unseen_prevs[child_id].discard(event.event_id)
+
+                    written_events.add(event.event_id)
+
+                logger.info(
+                    "Written %d events in room %s", len(written_events), room_id
+                )
+
+            # Extremities are the events who have at least one unseen prev event.
+            # The state at each of them is written out so that the gaps tracked
+            # above are recorded alongside the events themselves.
+            for event_id, unseen_prevs in event_to_unseen_prevs.items():
+                if not unseen_prevs:
+                    continue
+                state = yield self.store.get_state_for_event(event_id)
+                writer.write_state(room_id, event_id, state)
+
+        defer.returnValue(writer.finished())
+
+
+class ExfiltrationWriter(object):
+    """Interface used to specify how to write exported data. The methods
+    defined here do nothing."""
+
+    def write_events(self, room_id, events):
+        """Write a batch of events for a room.
+
+        Args:
+            room_id (str): The room the events belong to.
+            events (list[FrozenEvent]): The events to write.
+        """
+        pass
+
+    def write_state(self, room_id, event_id, state):
+        """Write the state at the given event in the room.
+
+        This only gets called for backward extremities rather than for each
+        event.
+
+        Args:
+            room_id (str): The room the event is in.
+            event_id (str): The event at which the state was taken.
+            state (dict[tuple[str, str], FrozenEvent]): The state at the event.
+        """
+        pass
+
+    def write_invite(self, room_id, event, state):
+        """Write an invite for the room, with associated invite state.
+
+        Args:
+            room_id (str): The room the user was invited to.
+            event (FrozenEvent): The invite event.
+            state (dict[tuple[str, str], dict]): A subset of the state at the
+                invite, with a subset of the event keys (type, state_key,
+                content and sender)
+        """
+
+    def finished(self):
+        """Called when all data has successfully been exported and written.
+
+        This function's return value is passed to the caller of
+        `export_user_data`.
+        """
+        pass
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8004aeb909..32cfd010a5 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -575,6 +575,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
count = yield self.runInteraction("did_forget_membership", f)
defer.returnValue(count == 0)
+    @defer.inlineCallbacks
+    def get_rooms_user_has_been_in(self, user_id):
+        """Fetch the IDs of every room the user has a join membership row for.
+
+        Args:
+            user_id (str): The user to look up.
+
+        Returns:
+            Deferred[set[str]]: The room IDs, with duplicates removed.
+        """
+
+        joined_room_ids = yield self._simple_select_onecol(
+            desc="get_rooms_user_has_been_in",
+            table="room_memberships",
+            retcol="room_id",
+            keyvalues={"membership": Membership.JOIN, "user_id": user_id},
+        )
+
+        defer.returnValue(set(joined_room_ids))
+
class RoomMemberStore(RoomMemberWorkerStore):
def __init__(self, db_conn, hs):
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 386a9dbe14..a0465484df 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -833,7 +833,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
as a list of _EventDictReturn and a token that points to the end
- of the result set.
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between
+ `from_token` and `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -905,15 +907,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
only those before
direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return. Zero or less
- means no limit.
+ limit (int): The maximum number of events to return.
event_filter (Filter|None): If provided filters the events to
those that match the filter.
Returns:
- tuple[list[dict], str]: Returns the results as a list of dicts and
- a token that points to the end of the result set. The dicts have
- the keys "event_id", "topological_ordering" and "stream_orderign".
+ tuple[list[FrozenEvent], str]: Returns the results as a list of
+ events and a token that points to the end of the result set. If no
+ events are returned then the end of the stream has been reached
+ (i.e. there are no events between `from_key` and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
new file mode 100644
index 0000000000..fc37c4328c
--- /dev/null
+++ b/tests/handlers/test_admin.py
@@ -0,0 +1,210 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import Counter
+
+from mock import Mock
+
+import synapse.api.errors
+import synapse.handlers.admin
+import synapse.rest.admin
+import synapse.storage
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class ExfiltrateData(unittest.HomeserverTestCase):
+    """Tests for the admin handler's `export_user_data`, using a Mock writer."""
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        """Grab the admin handler, then register and log in two users."""
+        self.admin_handler = hs.get_handlers().admin_handler
+
+        self.user1 = self.register_user("user1", "password")
+        self.token1 = self.login("user1", "password")
+
+        self.user2 = self.register_user("user2", "password")
+        self.token2 = self.login("user2", "password")
+
+    def test_single_public_joined_room(self):
+        """Test that we write *all* events for a public room."""
+        room_id = self.helper.create_room_as(
+            self.user1, tok=self.token1, is_public=True
+        )
+        self.helper.send(room_id, body="Hello!", tok=self.token1)
+        self.helper.join(room_id, self.user2, tok=self.token2)
+        self.helper.send(room_id, body="Hello again!", tok=self.token1)
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        writer.write_events.assert_called()
+
+        # Since we can see all events there shouldn't be any extremities, so no
+        # state should be written
+        writer.write_state.assert_not_called()
+
+        # Collect all events that were written
+        written_events = []
+        for (called_room_id, events), _ in writer.write_events.call_args_list:
+            self.assertEqual(called_room_id, room_id)
+            written_events.extend(events)
+
+        # Check that the right number of events were written
+        counter = Counter(
+            (event.type, getattr(event, "state_key", None)) for event in written_events
+        )
+        self.assertEqual(counter[(EventTypes.Message, None)], 2)
+        self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+        self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
+
+    def test_single_private_joined_room(self):
+        """Tests that we correctly write state when we can't see all events in
+        a room.
+        """
+        room_id = self.helper.create_room_as(self.user1, tok=self.token1)
+        self.helper.send_state(
+            room_id,
+            EventTypes.RoomHistoryVisibility,
+            body={"history_visibility": "joined"},
+            tok=self.token1,
+        )
+        self.helper.send(room_id, body="Hello!", tok=self.token1)
+        self.helper.join(room_id, self.user2, tok=self.token2)
+        self.helper.send(room_id, body="Hello again!", tok=self.token1)
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        writer.write_events.assert_called()
+
+        # Since we can't see all events there should be one extremity.
+        writer.write_state.assert_called_once()
+
+        # Collect all events that were written
+        written_events = []
+        for (called_room_id, events), _ in writer.write_events.call_args_list:
+            self.assertEqual(called_room_id, room_id)
+            written_events.extend(events)
+
+        # Check that the right number of events were written
+        counter = Counter(
+            (event.type, getattr(event, "state_key", None)) for event in written_events
+        )
+        self.assertEqual(counter[(EventTypes.Message, None)], 1)
+        self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+        self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
+
+    def test_single_left_room(self):
+        """Tests that we don't see events in the room after we leave."""
+        room_id = self.helper.create_room_as(self.user1, tok=self.token1)
+        self.helper.send(room_id, body="Hello!", tok=self.token1)
+        self.helper.join(room_id, self.user2, tok=self.token2)
+        self.helper.send(room_id, body="Hello again!", tok=self.token1)
+        self.helper.leave(room_id, self.user2, tok=self.token2)
+        self.helper.send(room_id, body="Helloooooo!", tok=self.token1)
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        writer.write_events.assert_called()
+
+        # Since we can see all events there shouldn't be any extremities, so no
+        # state should be written
+        writer.write_state.assert_not_called()
+
+        written_events = []
+        for (called_room_id, events), _ in writer.write_events.call_args_list:
+            self.assertEqual(called_room_id, room_id)
+            written_events.extend(events)
+
+        # Check that the right number of events were written
+        counter = Counter(
+            (event.type, getattr(event, "state_key", None)) for event in written_events
+        )
+        self.assertEqual(counter[(EventTypes.Message, None)], 2)
+        self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+        self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
+
+    def test_single_left_rejoined_private_room(self):
+        """Tests that we see the correct events in private rooms when we
+        repeatedly join and leave.
+        """
+        room_id = self.helper.create_room_as(self.user1, tok=self.token1)
+        self.helper.send_state(
+            room_id,
+            EventTypes.RoomHistoryVisibility,
+            body={"history_visibility": "joined"},
+            tok=self.token1,
+        )
+        self.helper.send(room_id, body="Hello!", tok=self.token1)
+        self.helper.join(room_id, self.user2, tok=self.token2)
+        self.helper.send(room_id, body="Hello again!", tok=self.token1)
+        self.helper.leave(room_id, self.user2, tok=self.token2)
+        self.helper.send(room_id, body="Helloooooo!", tok=self.token1)
+        self.helper.join(room_id, self.user2, tok=self.token2)
+        self.helper.send(room_id, body="Helloooooo!!", tok=self.token1)
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        writer.write_events.assert_called_once()
+
+        # Since we joined/left/joined again we expect there to be two gaps.
+        self.assertEqual(writer.write_state.call_count, 2)
+
+        written_events = []
+        for (called_room_id, events), _ in writer.write_events.call_args_list:
+            self.assertEqual(called_room_id, room_id)
+            written_events.extend(events)
+
+        # Check that the right number of events were written
+        counter = Counter(
+            (event.type, getattr(event, "state_key", None)) for event in written_events
+        )
+        self.assertEqual(counter[(EventTypes.Message, None)], 2)
+        self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+        self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
+
+    def test_invite(self):
+        """Tests that pending invites get handled correctly."""
+        room_id = self.helper.create_room_as(self.user1, tok=self.token1)
+        self.helper.send(room_id, body="Hello!", tok=self.token1)
+        self.helper.invite(room_id, self.user1, self.user2, tok=self.token1)
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        writer.write_events.assert_not_called()
+        writer.write_state.assert_not_called()
+        writer.write_invite.assert_called_once()
+
+        args = writer.write_invite.call_args[0]
+        self.assertEqual(args[0], room_id)
+        self.assertEqual(args[1].content["membership"], "invite")
+        self.assertTrue(args[2])  # Assert there is at least one bit of state
diff --git a/tests/unittest.py b/tests/unittest.py
index 0f0c2ad69d..cabe787cb4 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -472,7 +472,7 @@ class HomeserverTestCase(TestCase):
"POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
)
self.render(request)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
return user_id
|