From 31a9eceda5cf00b0482baf1c8bf1e138c823f621 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 30 Mar 2016 15:58:20 +0100 Subject: Add a replication stream for state groups --- tests/replication/test_resource.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) (limited to 'tests/replication') diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index f4b5fb3328..b1dd7b4a74 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -58,15 +58,21 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body, {}) @defer.inlineCallbacks - def test_events(self): - get = self.get(events="-1", timeout="0") + def test_events_and_state(self): + get = self.get(events="-1", state="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( Requester(self.user, "", False), {} ) code, body = yield get self.assertEquals(code, 200) self.assertEquals(body["events"]["field_names"], [ - "position", "internal", "json" + "position", "internal", "json", "state_group" + ]) + self.assertEquals(body["state_groups"]["field_names"], [ + "position", "room_id", "event_id" + ]) + self.assertEquals(body["state_group_state"]["field_names"], [ + "position", "type", "state_key", "event_id" ]) @defer.inlineCallbacks @@ -132,6 +138,7 @@ class ReplicationResourceCase(unittest.TestCase): test_timeout_backfill = _test_timeout("backfill") test_timeout_push_rules = _test_timeout("push_rules") test_timeout_pushers = _test_timeout("pushers") + test_timeout_state = _test_timeout("state") @defer.inlineCallbacks def send_text_message(self, room_id, message): @@ -182,4 +189,21 @@ class ReplicationResourceCase(unittest.TestCase): ) response_body = json.loads(response_json) + if response_code == 200: + self.check_response(response_body) + defer.returnValue((response_code, response_body)) + + def check_response(self, response_body): + for name, stream in response_body.items(): + self.assertIn("field_names", stream) + field_names = stream["field_names"] + self.assertIn("rows", stream) + self.assertTrue(stream["rows"]) + for row in stream["rows"]: + self.assertEquals( + len(row), len(field_names), + "%s: len(row = %r) == len(field_names = %r)" % ( + name, row, field_names + ) + ) -- cgit 1.5.1 From 75fb9ac1be0fada60cdde38153ac0e3fe2b19a0a Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 14:12:51 +0100 Subject: Add a slaved events store class Add a test to check that get_room_names_and_aliases does the same thing on both the master and on the slave data store. --- synapse/replication/slave/__init__.py | 14 ++ synapse/replication/slave/storage/__init__.py | 14 ++ synapse/replication/slave/storage/_base.py | 28 +++ .../slave/storage/_slaved_id_tracker.py | 30 ++++ synapse/replication/slave/storage/events.py | 198 +++++++++++++++++++++ synapse/storage/events.py | 4 +- tests/replication/slave/__init__.py | 14 ++ tests/replication/slave/storage/__init__.py | 14 ++ tests/replication/slave/storage/_base.py | 57 ++++++ tests/replication/slave/storage/test_events.py | 114 ++++++++++++ 10 files changed, 485 insertions(+), 2 deletions(-) create mode 100644 synapse/replication/slave/__init__.py create mode 100644 synapse/replication/slave/storage/__init__.py create mode 100644 synapse/replication/slave/storage/_base.py create mode 100644 synapse/replication/slave/storage/_slaved_id_tracker.py create mode 100644 synapse/replication/slave/storage/events.py create mode 100644 tests/replication/slave/__init__.py create mode 100644 tests/replication/slave/storage/__init__.py create mode 100644 tests/replication/slave/storage/_base.py create mode 100644 tests/replication/slave/storage/test_events.py (limited to 'tests/replication') diff --git a/synapse/replication/slave/__init__.py b/synapse/replication/slave/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/__init__.py b/synapse/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py new file mode 100644 index 0000000000..46e43ce1c7 --- /dev/null +++ b/synapse/replication/slave/storage/_base.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage._base import SQLBaseStore +from twisted.internet import defer + + +class BaseSlavedStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(BaseSlavedStore, self).__init__(hs) + + def stream_positions(self): + return {} + + def process_replication(self, result): + return defer.succeed(None) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py new file mode 100644 index 0000000000..24b5c79d4a --- /dev/null +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.util.id_generators import _load_current_id + + +class SlavedIdTracker(object): + def __init__(self, db_conn, table, column, extra_tables=[], step=1): + self.step = step + self._current = _load_current_id(db_conn, table, column, step) + for table, column in extra_tables: + self.advance(_load_current_id(db_conn, table, column)) + + def advance(self, new_id): + self._current = (max if self.step > 0 else min)(self._current, new_id) + + def get_current_token(self): + return self._current diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py new file mode 100644 index 0000000000..68b924e37b --- /dev/null +++ b/synapse/replication/slave/storage/events.py @@ -0,0 +1,198 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.api.constants import EventTypes +from synapse.events import FrozenEvent +from synapse.storage import DataStore +from synapse.storage.room import RoomStore +from synapse.storage.roommember import RoomMemberStore +from synapse.storage.event_federation import EventFederationStore +from synapse.storage.state import StateStore +from synapse.util.caches.stream_change_cache import StreamChangeCache + +import ujson as json + +# So, um, we want to borrow a load of functions intended for reading from +# a DataStore, but we don't want to take functions that either write to the +# DataStore or are cached and don't have cache invalidation logic. +# +# Rather than write duplicate versions of those functions, or lift them to +# a common base class, we going to grab the underlying __func__ object from +# the method descriptor on the DataStore and chuck them into our class. + + +class SlavedEventStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedEventStore, self).__init__(db_conn, hs) + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + events_max = self._stream_id_gen.get_current_token() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + + # Cached functions can't be accessed through a class instance so we need + # to reach inside the __dict__ to extract them. + get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"] + get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] + get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] + get_latest_event_ids_in_room = EventFederationStore.__dict__[ + "get_latest_event_ids_in_room" + ] + _get_current_state_for_key = StateStore.__dict__[ + "_get_current_state_for_key" + ] + + get_current_state = DataStore.get_current_state.__func__ + get_current_state_for_key = DataStore.get_current_state_for_key.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + get_rooms_for_user_where_membership_is = ( + DataStore.get_rooms_for_user_where_membership_is.__func__ + ) + get_membership_changes_for_user = ( + DataStore.get_membership_changes_for_user.__func__ + ) + get_room_events_max_id = DataStore.get_room_events_max_id.__func__ + get_room_events_stream_for_room = ( + DataStore.get_room_events_stream_for_room.__func__ + ) + _set_before_and_after = DataStore._set_before_and_after + + _get_events = DataStore._get_events.__func__ + _get_events_from_cache = DataStore._get_events_from_cache.__func__ + + _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ + _parse_events_txn = DataStore._parse_events_txn.__func__ + _get_events_txn = DataStore._get_events_txn.__func__ + _fetch_events_txn = DataStore._fetch_events_txn.__func__ + _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ + + def stream_positions(self): + result = super(SlavedEventStore, self).stream_positions() + result["events"] = self._stream_id_gen.get_current_token() + result["backfilled"] = self._backfill_id_gen.get_current_token() + return result + + def process_replication(self, result): + state_resets = set( + r[0] for r in result.get("state_resets", {"rows": []})["rows"] + ) + + stream = result.get("events") + if stream: + self._stream_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=False, state_resets=state_resets + ) + + stream = result.get("backfill") + if stream: + self._backfill_id_gen.advance(stream["position"]) + for row in stream["rows"]: + self._process_replication_row( + row, backfilled=True, state_resets=state_resets + ) + + stream = result.get("forward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + stream = result.get("backward_ex_outliers") + if stream: + for row in stream["rows"]: + event_id = row[1] + self._invalidate_get_event_cache(event_id) + + return super(SlavedEventStore, self).process_replication(result) + + def _process_replication_row(self, row, backfilled, state_resets): + position = row[0] + internal = json.loads(row[1]) + event_json = json.loads(row[2]) + + event = FrozenEvent(event_json, internal_metadata_dict=internal) + self._invalidate_caches_for_event( + event, backfilled, reset_state=position in state_resets + ) + + def _invalidate_caches_for_event(self, event, backfilled, reset_state): + if reset_state: + self._get_current_state_for_key.invalidate_all() + self.get_rooms_for_user.invalidate_all() + self.get_users_in_room.invalidate((event.room_id,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_room_name_and_aliases.invalidate((event.room_id,)) + + self._invalidate_get_event_cache(event.event_id) + + if not backfilled: + self._events_stream_cache.entity_has_changed( + event.room_id, event.internal_metadata.stream_ordering + ) + + # self.get_unread_event_push_actions_by_room_for_user.invalidate_many( + # (event.room_id,) + # ) + + if event.type == EventTypes.Redaction: + self._invalidate_get_event_cache(event.redacts) + + if event.type == EventTypes.Member: + self.get_rooms_for_user.invalidate((event.state_key,)) + # self.get_joined_hosts_for_room.invalidate((event.room_id,)) + self.get_users_in_room.invalidate((event.room_id,)) + # self._membership_stream_cache.entity_has_changed( + # event.state_key, event.internal_metadata.stream_ordering + # ) + + if not event.is_state(): + return + + if backfilled: + return + + if (not event.internal_metadata.is_invite_from_remote() + and event.internal_metadata.is_outlier()): + return + + self._get_current_state_for_key.invalidate(( + event.room_id, event.type, event.state_key + )) + + if event.type in [EventTypes.Name, EventTypes.Aliases]: + self.get_room_name_and_aliases.invalidate( + (event.room_id,) + ) + pass diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5d299a1132..ee87a71719 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1134,7 +1134,7 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT -event_stream_ordering FROM current_state_resets" + "SELECT event_stream_ordering FROM current_state_resets" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering ASC" @@ -1143,7 +1143,7 @@ class EventsStore(SQLBaseStore): state_resets = txn.fetchall() sql = ( - "SELECT -event_stream_ordering, event_id, state_group" + "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" diff --git a/tests/replication/slave/__init__.py b/tests/replication/slave/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/slave/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/__init__.py b/tests/replication/slave/storage/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/slave/storage/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py new file mode 100644 index 0000000000..0f525a8943 --- /dev/null +++ b/tests/replication/slave/storage/_base.py @@ -0,0 +1,57 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from tests import unittest + +from synapse.replication.slave.storage.events import SlavedEventStore + +from mock import Mock, NonCallableMock +from tests.utils import setup_test_homeserver +from synapse.replication.resource import ReplicationResource + + +class BaseSlavedStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + "blue", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=[ + "send_message", + ]), + ) + self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + self.replication = ReplicationResource(self.hs) + + self.master_store = self.hs.get_datastore() + self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) + self.event_id = 0 + + @defer.inlineCallbacks + def replicate(self): + streams = self.slaved_store.stream_positions() + result = yield self.replication.replicate(streams, 100) + yield self.slaved_store.process_replication(result) + + @defer.inlineCallbacks + def check(self, method, args, expected_result=None): + master_result = yield getattr(self.master_store, method)(*args) + slaved_result = yield getattr(self.slaved_store, method)(*args) + self.assertEqual(master_result, slaved_result) + if expected_result is not None: + self.assertEqual(master_result, expected_result) + self.assertEqual(slaved_result, expected_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py new file mode 100644 index 0000000000..c30c7c6063 --- /dev/null +++ b/tests/replication/slave/storage/test_events.py @@ -0,0 +1,114 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStoreTestCase + +from synapse.types import UserID +from synapse.events import FrozenEvent +from synapse.events.snapshot import EventContext + +from twisted.internet import defer + +USER_ID = "@feeling:blue" +USER = UserID.from_string(USER_ID) +OUTLIER = {"outlier": True} +ROOM_ID = "!room:blue" + + +class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + + @defer.inlineCallbacks + def test_room_name_and_aliases(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.persist(type="m.room.name", key="", name="name1") + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#1:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"]) + ) + + # Set the room name. + yield self.persist(type="m.room.name", key="", name="name2") + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"]) + ) + + # Set the room aliases. + yield self.persist( + type="m.room.aliases", key="blue", aliases=["#2:blue"] + ) + yield self.replicate() + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"]) + ) + + # Leave and join the room clobbering the state. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + + yield self.check( + "get_room_name_and_aliases", (ROOM_ID,), (None, []) + ) + + event_id = 0 + + @defer.inlineCallbacks + def persist( + self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, + internal={}, + state=None, reset_state=False, backfill=False, + depth=None, prev_events=[], auth_events=[], prev_state=[], + **content + ): + if depth is None: + depth = self.event_id + + event_dict = { + "sender": sender, + "type": type, + "content": content, + "event_id": "$%d:blue" % (self.event_id,), + "room_id": room_id, + "depth": depth, + "origin_server_ts": self.event_id, + "prev_events": prev_events, + "auth_events": auth_events, + } + if key is not None: + event_dict["state_key"] = key + event_dict["prev_state"] = prev_state + + event = FrozenEvent(event_dict, internal_metadata_dict=internal) + + self.event_id += 1 + + context = EventContext(current_state=state) + + if backfill: + yield self.master_store.persist_events( + [(event, context)], backfilled=True + ) + else: + yield self.master_store.persist_event( + event, context, current_state=reset_state + ) + defer.returnValue(event) -- cgit 1.5.1 From 6bfec56796132520ad864ad00f8dd6773512f9f4 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 6 Apr 2016 16:17:15 +0100 Subject: Test that room membership is replicated --- synapse/replication/slave/storage/events.py | 7 +-- tests/replication/slave/storage/test_events.py | 71 +++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 11 deletions(-) (limited to 'tests/replication') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 68b924e37b..680dc89536 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -71,9 +71,6 @@ class SlavedEventStore(BaseSlavedStore): get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ - _get_rooms_for_user_where_membership_is_txn = ( - DataStore._get_rooms_for_user_where_membership_is_txn.__func__ - ) get_rooms_for_user_where_membership_is = ( DataStore.get_rooms_for_user_where_membership_is.__func__ ) @@ -95,6 +92,10 @@ class SlavedEventStore(BaseSlavedStore): _fetch_events_txn = DataStore._fetch_events_txn.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__ _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ + _get_rooms_for_user_where_membership_is_txn = ( + DataStore._get_rooms_for_user_where_membership_is_txn.__func__ + ) + _get_members_rows_txn = DataStore._get_members_rows_txn.__func__ def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index c30c7c6063..351d777fb2 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -14,14 +14,14 @@ from ._base import BaseSlavedStoreTestCase -from synapse.types import UserID from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext +from synapse.storage.roommember import RoomsForUser from twisted.internet import defer USER_ID = "@feeling:blue" -USER = UserID.from_string(USER_ID) +USER_ID_2 = "@bright:blue" OUTLIER = {"outlier": True} ROOM_ID = "!room:blue" @@ -69,16 +69,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "get_room_name_and_aliases", (ROOM_ID,), (None, []) ) + @defer.inlineCallbacks + def test_room_members(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), []) + yield self.check("get_users_in_room", (ROOM_ID,), []) + + # Join the room. + join = yield self.persist(type="m.room.member", key=USER_ID, membership="join") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser( + room_id=ROOM_ID, + sender=USER_ID, + membership="join", + event_id=join.event_id, + stream_ordering=join.internal_metadata.stream_ordering, + )]) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) + + # Leave the room. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID,), []) + yield self.check("get_users_in_room", (ROOM_ID,), []) + + # Add some other user to the room. + join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join") + yield self.replicate() + yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser( + room_id=ROOM_ID, + sender=USER_ID, + membership="join", + event_id=join.event_id, + stream_ordering=join.internal_metadata.stream_ordering, + )]) + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) + + # Join the room clobbering the state. + # This should remove any evidence of the other user being in the room. + yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) + yield self.check("get_rooms_for_user", (USER_ID_2,), []) + event_id = 0 @defer.inlineCallbacks def persist( - self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, - internal={}, - state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], - **content + self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, + state=None, reset_state=False, backfill=False, + depth=None, prev_events=[], auth_events=[], prev_state=[], + **content ): + """ + Returns: + synapse.events.FrozenEvent: The event that was persisted. + """ if depth is None: depth = self.event_id @@ -103,12 +153,17 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): context = EventContext(current_state=state) + ordering = None if backfill: yield self.master_store.persist_events( [(event, context)], backfilled=True ) else: - yield self.master_store.persist_event( + ordering, _ = yield self.master_store.persist_event( event, context, current_state=reset_state ) + + if ordering: + event.internal_metadata.stream_ordering = ordering + defer.returnValue(event) -- cgit 1.5.1 From 60ec9793fb44ad445dd1233594957baeede60e4f Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 7 Apr 2016 13:17:56 +0100 Subject: Add tests for get_latest_event_ids_in_room and get_current_state --- synapse/events/__init__.py | 9 ++++ synapse/replication/slave/storage/events.py | 5 +++ tests/replication/slave/storage/test_events.py | 62 ++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) (limited to 'tests/replication') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 13154b1723..81e2126202 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -36,6 +36,10 @@ class _EventInternalMetadata(object): def is_invite_from_remote(self): return getattr(self, "invite_from_remote", False) + def __eq__(self, other): + "Equality check for unit tests." + return self.__dict__ == other.__dict__ + def _event_dict_property(key): def getter(self): @@ -180,3 +184,8 @@ class FrozenEvent(EventBase): self.get("type", None), self.get("state_key", None), ) + + def __eq__(self, other): + """Equality check for unit tests. Compares internal_metadata as well + as the event fields""" + return self.__dict__ == other.__dict__ diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 680dc89536..707ddd248a 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -89,8 +89,11 @@ class SlavedEventStore(BaseSlavedStore): _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ _parse_events_txn = DataStore._parse_events_txn.__func__ _get_events_txn = DataStore._get_events_txn.__func__ + _enqueue_events = DataStore._enqueue_events.__func__ + _do_fetch = DataStore._do_fetch.__func__ _fetch_events_txn = DataStore._fetch_events_txn.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__ + _get_event_from_row = DataStore._get_event_from_row.__func__ _get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__ _get_rooms_for_user_where_membership_is_txn = ( DataStore._get_rooms_for_user_where_membership_is_txn.__func__ @@ -158,6 +161,8 @@ class SlavedEventStore(BaseSlavedStore): self._invalidate_get_event_cache(event.event_id) + self.get_latest_event_ids_in_room.invalidate((event.room_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed( event.room_id, event.internal_metadata.stream_ordering diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 351d777fb2..d5d0ef1148 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -116,6 +116,68 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) yield self.check("get_rooms_for_user", (USER_ID_2,), []) + @defer.inlineCallbacks + def test_get_latest_event_ids_in_room(self): + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check( + "get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id] + ) + + join = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + prev_events=[(create.event_id, {})], + ) + yield self.replicate() + yield self.check( + "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] + ) + + @defer.inlineCallbacks + def test_get_current_state(self): + # Create the room. + create = yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] + ) + + # Join the room. + join1 = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), + [join1] + ) + + # Add some other user to the room. + join2 = yield self.persist( + type="m.room.member", key=USER_ID_2, membership="join", + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), + [join2] + ) + + # Leave the room, then rejoin the room clobbering state. + yield self.persist(type="m.room.member", key=USER_ID, membership="leave") + join3 = yield self.persist( + type="m.room.member", key=USER_ID, membership="join", + reset_state=[create] + ) + yield self.replicate() + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), + [] + ) + yield self.check( + "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), + [join3] + ) + event_id = 0 @defer.inlineCallbacks -- cgit 1.5.1 From 57fa1801c336603c6c710d6d868aaae596a7f5b8 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 7 Apr 2016 16:41:37 +0100 Subject: Add sensible __eq__ operators inside the tests. Rather than adding them globally. This limits the changes to only affect the tests. --- synapse/events/__init__.py | 9 -------- tests/replication/slave/storage/test_events.py | 29 +++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 10 deletions(-) (limited to 'tests/replication') diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 81e2126202..13154b1723 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -36,10 +36,6 @@ class _EventInternalMetadata(object): def is_invite_from_remote(self): return getattr(self, "invite_from_remote", False) - def __eq__(self, other): - "Equality check for unit tests." - return self.__dict__ == other.__dict__ - def _event_dict_property(key): def getter(self): @@ -184,8 +180,3 @@ class FrozenEvent(EventBase): self.get("type", None), self.get("state_key", None), ) - - def __eq__(self, other): - """Equality check for unit tests. Compares internal_metadata as well - as the event fields""" - return self.__dict__ == other.__dict__ diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index d5d0ef1148..9af62702b3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -14,20 +14,47 @@ from ._base import BaseSlavedStoreTestCase -from synapse.events import FrozenEvent +from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.storage.roommember import RoomsForUser from twisted.internet import defer + USER_ID = "@feeling:blue" USER_ID_2 = "@bright:blue" OUTLIER = {"outlier": True} ROOM_ID = "!room:blue" +def dict_equals(self, other): + return self.__dict__ == other.__dict__ + + +def patch__eq__(cls): + eq = getattr(cls, "__eq__", None) + cls.__eq__ = dict_equals + + def unpatch(): + if eq is not None: + cls.__eq__ = eq + return unpatch + + class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + def setUp(self): + # Patch up the equality operator for events so that we can check + # whether lists of events match using assertEquals + self.unpatches = [ + patch__eq__(_EventInternalMetadata), + patch__eq__(FrozenEvent), + ] + return super(SlavedEventStoreTestCase, self).setUp() + + def tearDown(self): + [unpatch() for unpatch in self.unpatches] + @defer.inlineCallbacks def test_room_name_and_aliases(self): create = yield self.persist(type="m.room.create", key="", creator=USER_ID) -- cgit 1.5.1 From ceb599e789ef34a3733a99d17701a75fc5409d17 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 7 Apr 2016 16:26:52 +0100 Subject: Add tests for redactions --- synapse/replication/slave/storage/events.py | 4 +- synapse/storage/util/id_generators.py | 2 +- tests/replication/slave/storage/_base.py | 2 +- tests/replication/slave/storage/test_events.py | 51 +++++++++++++++++++++++++- 4 files changed, 54 insertions(+), 5 deletions(-) (limited to 'tests/replication') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 707ddd248a..cfc728a038 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore): "_get_current_state_for_key" ] + get_event = DataStore.get_event.__func__ get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( @@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore): def stream_positions(self): result = super(SlavedEventStore, self).stream_positions() result["events"] = self._stream_id_gen.get_current_token() - result["backfilled"] = self._backfill_id_gen.get_current_token() + result["backfill"] = self._backfill_id_gen.get_current_token() return result def process_replication(self, result): @@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore): position = row[0] internal = json.loads(row[1]) event_json = json.loads(row[2]) - event = FrozenEvent(event_json, internal_metadata_dict=internal) self._invalidate_caches_for_event( event, backfilled, reset_state=position in state_resets diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f69f1cdad4..46cf93ff87 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -112,7 +112,7 @@ class StreamIdGenerator(object): self._current + self._step * (n + 1), self._step ) - self._current += n + self._current += n * self._step for next_id in next_ids: self._unfinished_ids.append(next_id) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 0f525a8943..983caafe8a 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): def check(self, method, args, expected_result=None): master_result = yield getattr(self.master_store, method)(*args) slaved_result = yield getattr(self.slaved_store, method)(*args) - self.assertEqual(master_result, slaved_result) if expected_result is not None: self.assertEqual(master_result, expected_result) self.assertEqual(slaved_result, expected_result) + self.assertEqual(master_result, slaved_result) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 9af62702b3..baa4a26eb5 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): [join3] ) + @defer.inlineCallbacks + def test_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + + @defer.inlineCallbacks + def test_backfilled_redactions(self): + yield self.persist(type="m.room.create", key="", creator=USER_ID) + yield self.persist(type="m.room.member", key=USER_ID, membership="join") + + msg = yield self.persist( + type="m.room.message", msgtype="m.text", body="Hello" + ) + yield self.replicate() + yield self.check("get_event", [msg.event_id], msg) + + redaction = yield self.persist( + type="m.room.redaction", redacts=msg.event_id, backfill=True + ) + yield self.replicate() + + msg_dict = msg.get_dict() + msg_dict["content"] = {} + msg_dict["unsigned"]["redacted_by"] = redaction.event_id + msg_dict["unsigned"]["redacted_because"] = redaction + redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) + yield self.check("get_event", [msg.event_id], redacted) + event_id = 0 @defer.inlineCallbacks def persist( self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, state=None, reset_state=False, backfill=False, - depth=None, prev_events=[], auth_events=[], prev_state=[], + depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, **content ): """ @@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_dict["state_key"] = key event_dict["prev_state"] = prev_state + if redacts is not None: + event_dict["redacts"] = redacts + event = FrozenEvent(event_dict, internal_metadata_dict=internal) self.event_id += 1 -- cgit 1.5.1 From e99365f6015af6dc0c2c107f47bda3760ff1153e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 19 Apr 2016 15:22:14 +0100 Subject: Replicate get_invited_rooms_for_user --- synapse/replication/slave/storage/events.py | 9 +++++++-- tests/replication/slave/storage/test_events.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) (limited to 'tests/replication') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index cfc728a038..82f171c257 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -68,6 +68,9 @@ class SlavedEventStore(BaseSlavedStore): _get_current_state_for_key = StateStore.__dict__[ "_get_current_state_for_key" ] + get_invited_rooms_for_user = RoomMemberStore.__dict__[ + "get_invited_rooms_for_user" + ] get_event = DataStore.get_event.__func__ get_current_state = DataStore.get_current_state.__func__ @@ -82,6 +85,7 @@ class SlavedEventStore(BaseSlavedStore): get_room_events_stream_for_room = ( DataStore.get_room_events_stream_for_room.__func__ ) + _set_before_and_after = DataStore._set_before_and_after _get_events = DataStore._get_events.__func__ @@ -147,11 +151,11 @@ class SlavedEventStore(BaseSlavedStore): internal = json.loads(row[1]) event_json = json.loads(row[2]) event = FrozenEvent(event_json, internal_metadata_dict=internal) - self._invalidate_caches_for_event( + self.invalidate_caches_for_event( event, backfilled, reset_state=position in state_resets ) - def _invalidate_caches_for_event(self, event, backfilled, reset_state): + def invalidate_caches_for_event(self, event, backfilled, reset_state): if reset_state: self._get_current_state_for_key.invalidate_all() self.get_rooms_for_user.invalidate_all() @@ -182,6 +186,7 @@ class SlavedEventStore(BaseSlavedStore): # self._membership_stream_cache.entity_has_changed( # event.state_key, event.internal_metadata.stream_ordering # ) + self.get_invited_rooms_for_user.invalidate((event.state_key,)) if not event.is_state(): return diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index baa4a26eb5..88b8d08110 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -251,6 +251,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) yield self.check("get_event", [msg.event_id], redacted) + @defer.inlineCallbacks + def test_invites(self): + yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) + event = yield self.persist( + type="m.room.member", key=USER_ID_2, membership="invite" + ) + yield self.replicate() + yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser( + ROOM_ID, USER_ID, "invite", event.event_id, + event.internal_metadata.stream_ordering + )]) + event_id = 0 @defer.inlineCallbacks -- cgit 1.5.1 From 5bbd424ee0c983d305014df618fa1917ecd10d91 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 19 Apr 2016 17:11:44 +0100 Subject: Add a slaved receipts store --- synapse/replication/slave/storage/receipts.py | 61 ++++++++++++++++++++++++ tests/replication/slave/storage/_base.py | 4 +- tests/replication/slave/storage/test_events.py | 3 ++ tests/replication/slave/storage/test_receipts.py | 39 +++++++++++++++ 4 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 synapse/replication/slave/storage/receipts.py create mode 100644 tests/replication/slave/storage/test_receipts.py (limited to 'tests/replication') diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py new file mode 100644 index 0000000000..b55d5dfd08 --- /dev/null +++ b/synapse/replication/slave/storage/receipts.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker + +from synapse.storage import DataStore +from synapse.storage.receipts import ReceiptsStore + +# So, um, we want to borrow a load of functions intended for reading from +# a DataStore, but we don't want to take functions that either write to the +# DataStore or are cached and don't have cache invalidation logic. +# +# Rather than write duplicate versions of those functions, or lift them to +# a common base class, we going to grab the underlying __func__ object from +# the method descriptor on the DataStore and chuck them into our class. + + +class SlavedReceiptsStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedReceiptsStore, self).__init__(db_conn, hs) + + self._receipts_id_gen = SlavedIdTracker( + db_conn, "receipts_linearized", "stream_id" + ) + + get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] + + get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__ + get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__ + + def stream_positions(self): + result = super(SlavedReceiptsStore, self).stream_positions() + result["receipts"] = self._receipts_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("receipts") + if stream: + self._receipts_id_gen.advance(stream["position"]) + for row in stream["rows"]: + room_id, receipt_type, user_id = row[1:4] + self.invalidate_caches_for_receipt(room_id, receipt_type, user_id) + + return super(SlavedReceiptsStore, self).process_replication(result) + + def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + self.get_receipts_for_user.invalidate((user_id, receipt_type)) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 983caafe8a..1f13cd0bc0 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -15,8 +15,6 @@ from twisted.internet import defer from tests import unittest -from synapse.replication.slave.storage.events import SlavedEventStore - from mock import Mock, NonCallableMock from tests.utils import setup_test_homeserver from synapse.replication.resource import ReplicationResource @@ -38,7 +36,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): self.replication = ReplicationResource(self.hs) self.master_store = self.hs.get_datastore() - self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) + self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 @defer.inlineCallbacks diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index baa4a26eb5..66f166047e 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -16,6 +16,7 @@ from ._base import BaseSlavedStoreTestCase from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext +from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser from twisted.internet import defer @@ -43,6 +44,8 @@ def patch__eq__(cls): class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): + STORE_TYPE = SlavedEventStore + def setUp(self): # Patch up the equality operator for events so that we can check # whether lists of events match using assertEquals diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py new file mode 100644 index 0000000000..6624fe4eea --- /dev/null +++ b/tests/replication/slave/storage/test_receipts.py @@ -0,0 +1,39 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStoreTestCase + +from synapse.replication.slave.storage.receipts import SlavedReceiptsStore + +from twisted.internet import defer + +USER_ID = "@feeling:blue" +ROOM_ID = "!room:blue" +EVENT_ID = "$event:blue" + + +class SlavedReceiptTestCase(BaseSlavedStoreTestCase): + + STORE_TYPE = SlavedReceiptsStore + + @defer.inlineCallbacks + def test_receipt(self): + yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) + yield self.master_store.insert_receipt( + ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} + ) + yield self.replicate() + yield self.check("get_receipts_for_user", [USER_ID, "m.read"], { + ROOM_ID: EVENT_ID + }) -- cgit 1.5.1 From c0d8e0eb6375d5b8754db420af43733a642c7f38 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 21 Apr 2016 15:25:47 +0100 Subject: Replicate push actions --- synapse/replication/slave/storage/events.py | 14 +++++++++ tests/replication/slave/storage/test_events.py | 43 ++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) (limited to 'tests/replication') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 82f171c257..5f37ba6995 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -21,6 +21,7 @@ from synapse.storage import DataStore from synapse.storage.room import RoomStore from synapse.storage.roommember import RoomMemberStore from synapse.storage.event_federation import EventFederationStore +from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.state import StateStore from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -71,7 +72,16 @@ class SlavedEventStore(BaseSlavedStore): get_invited_rooms_for_user = RoomMemberStore.__dict__[ "get_invited_rooms_for_user" ] + get_unread_event_push_actions_by_room_for_user = ( + EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] + ) + get_unread_push_actions_for_user_in_range = ( + DataStore.get_unread_push_actions_for_user_in_range.__func__ + ) + get_push_action_users_in_range = ( + DataStore.get_push_action_users_in_range.__func__ + ) get_event = DataStore.get_event.__func__ get_current_state = DataStore.get_current_state.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__ @@ -167,6 +177,10 @@ class SlavedEventStore(BaseSlavedStore): self.get_latest_event_ids_in_room.invalidate((event.room_id,)) + self.get_unread_event_push_actions_by_room_for_user.invalidate_many( + (event.room_id,) + ) + if not backfilled: self._events_stream_cache.entity_has_changed( event.room_id, event.internal_metadata.stream_ordering diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 41a626cf70..17587fda00 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -266,6 +266,47 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event.internal_metadata.stream_ordering )]) + @defer.inlineCallbacks + def test_push_actions_for_user(self): + yield self.persist(type="m.room.create", creator=USER_ID) + yield self.persist(type="m.room.join", key=USER_ID, membership="join") + yield self.persist( + type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" + ) + event1 = yield self.persist( + type="m.room.message", msgtype="m.text", body="hello" + ) + yield self.replicate() + yield self.check( + "get_unread_event_push_actions_by_room_for_user", + [ROOM_ID, USER_ID_2, event1.event_id], + {"highlight_count": 0, "notify_count": 0} + ) + + yield self.persist( + type="m.room.message", msgtype="m.text", body="world", + push_actions=[(USER_ID_2, ["notify"])], + ) + yield self.replicate() + yield self.check( + "get_unread_event_push_actions_by_room_for_user", + [ROOM_ID, USER_ID_2, event1.event_id], + {"highlight_count": 0, "notify_count": 1} + ) + + yield self.persist( + type="m.room.message", msgtype="m.text", body="world", + push_actions=[(USER_ID_2, [ + "notify", {"set_tweak": "highlight", "value": True} + ])], + ) + yield self.replicate() + yield self.check( + "get_unread_event_push_actions_by_room_for_user", + [ROOM_ID, USER_ID_2, event1.event_id], + {"highlight_count": 1, "notify_count": 2} + ) + event_id = 0 @defer.inlineCallbacks @@ -273,6 +314,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, state=None, reset_state=False, backfill=False, depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, + push_actions=[], **content ): """ @@ -305,6 +347,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.event_id += 1 context = EventContext(current_state=state) + context.push_actions = push_actions ordering = None if backfill: -- cgit 1.5.1 From 3abab26458cc9fe8a77d5ccee664e87ce407ed58 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 13 May 2016 15:34:06 +0100 Subject: Add a slaved datastore for account data --- synapse/replication/slave/storage/account_data.py | 61 ++++++++++++++++++++++ .../replication/slave/storage/test_account_data.py | 56 ++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 synapse/replication/slave/storage/account_data.py create mode 100644 tests/replication/slave/storage/test_account_data.py (limited to 'tests/replication') diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py new file mode 100644 index 0000000000..f59b0eabbc --- /dev/null +++ b/synapse/replication/slave/storage/account_data.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage.account_data import AccountDataStore + + +class SlavedAccountDataStore(BaseSlavedStore): + + def __init__(self, db_conn, hs): + super(SlavedAccountDataStore, self).__init__(db_conn, hs) + self._account_data_id_gen = SlavedIdTracker( + db_conn, "account_data_max_stream_id", "stream_id", + ) + + get_global_account_data_by_type_for_users = ( + AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] + ) + + get_global_account_data_by_type_for_user = ( + AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] + ) + + def stream_positions(self): + result = super(SlavedAccountDataStore, self).stream_positions() + position = self._account_data_id_gen.get_current_token() + result["user_account_data"] = position + result["room_account_data"] = position + result["tag_account_data"] = position + return result + + def process_replication(self, result): + stream = result.get("user_account_data") + if stream: + self._account_data_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + user_id, data_type = row[1:3] + self.get_global_account_data_by_type_for_user.invalidate( + (data_type, user_id,) + ) + + stream = result.get("room_account_data") + if stream: + self._account_data_id_gen.advance(int(stream["position"])) + + stream = result.get("tag_account_data") + if stream: + self._account_data_id_gen.advance(int(stream["position"])) diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py new file mode 100644 index 0000000000..da54d478ce --- /dev/null +++ b/tests/replication/slave/storage/test_account_data.py @@ -0,0 +1,56 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._base import BaseSlavedStoreTestCase + +from synapse.replication.slave.storage.account_data import SlavedAccountDataStore + +from twisted.internet import defer + +USER_ID = "@feeling:blue" +TYPE = "my.type" + + +class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): + + STORE_TYPE = SlavedAccountDataStore + + @defer.inlineCallbacks + def test_user_account_data(self): + yield self.master_store.add_account_data_for_user( + USER_ID, TYPE, {"a": 1} + ) + yield self.replicate() + yield self.check( + "get_global_account_data_by_type_for_user", + [TYPE, USER_ID], {"a": 1} + ) + yield self.check( + "get_global_account_data_by_type_for_users", + [TYPE, [USER_ID]], {USER_ID: {"a": 1}} + ) + + yield self.master_store.add_account_data_for_user( + USER_ID, TYPE, {"a": 2} + ) + yield self.replicate() + yield self.check( + "get_global_account_data_by_type_for_user", + [TYPE, USER_ID], {"a": 2} + ) + yield self.check( + "get_global_account_data_by_type_for_users", + [TYPE, [USER_ID]], {USER_ID: {"a": 2}} + ) -- cgit 1.5.1 From 3b86ecfa7965f4d29e0f5ce8fb663e5f018adf89 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 16 May 2016 18:56:37 +0100 Subject: Move the presence handler out of the Handlers object --- synapse/handlers/__init__.py | 2 -- synapse/handlers/events.py | 2 +- synapse/handlers/message.py | 4 ++-- synapse/handlers/presence.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/replication/resource.py | 2 +- synapse/rest/client/v1/presence.py | 20 ++++++++++++++------ synapse/rest/client/v1/room.py | 2 +- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/server.py | 5 +++++ tests/replication/test_resource.py | 2 +- 12 files changed, 29 insertions(+), 18 deletions(-) (limited to 'tests/replication') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index f4dbf47c1d..60e31b68ff 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -24,7 +24,6 @@ from .message import MessageHandler from .events import EventStreamHandler, EventHandler from .federation import FederationHandler from .profile import ProfileHandler -from .presence import PresenceHandler from .directory import DirectoryHandler from .typing import TypingNotificationHandler from .admin import AdminHandler @@ -53,7 +52,6 @@ class Handlers(object): self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) - self.presence_handler = PresenceHandler(hs) self.room_list_handler = RoomListHandler(hs) self.directory_handler = DirectoryHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f25a252523..3a3a1257d3 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -58,7 +58,7 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ auth_user = UserID.from_string(auth_user_id) - presence_handler = self.hs.get_handlers().presence_handler + presence_handler = self.hs.get_presence_handler() context = yield presence_handler.user_syncing( auth_user_id, affect_presence=affect_presence, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 13154edb78..c4e38d0faa 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -236,7 +236,7 @@ class MessageHandler(BaseHandler): ) if event.type == EventTypes.Message: - presence = self.hs.get_handlers().presence_handler + presence = self.hs.get_presence_handler() yield presence.bump_presence_active_time(user) def deduplicate_state_event(self, event, context): @@ -674,7 +674,7 @@ class MessageHandler(BaseHandler): and m.content["membership"] == Membership.JOIN ] - presence_handler = self.hs.get_handlers().presence_handler + presence_handler = self.hs.get_presence_handler() @defer.inlineCallbacks def get_presence(): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index a8529cce42..8aaaec7030 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -860,7 +860,7 @@ class PresenceEventSource(object): from_key = int(from_key) room_ids = room_ids or [] - presence = self.hs.get_handlers().presence_handler + presence = self.hs.get_presence_handler() stream_change_cache = self.store.presence_stream_cache if not room_ids: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 921215469f..b30102f472 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -639,7 +639,7 @@ class SyncHandler(BaseHandler): # For each newly joined room, we want to send down presence of # existing users. - presence_handler = self.hs.get_handlers().presence_handler + presence_handler = self.hs.get_presence_handler() extra_presence_users = set() for room_id in newly_joined_rooms: users = yield self.store.get_users_in_room(event.room_id) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 0e983ae7fa..b0e7a17670 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -109,7 +109,7 @@ class ReplicationResource(Resource): self.version_string = hs.version_string self.store = hs.get_datastore() self.sources = hs.get_event_sources() - self.presence_handler = hs.get_handlers().presence_handler + self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_handlers().typing_notification_handler self.notifier = hs.notifier self.clock = hs.get_clock() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 27d9ed586b..eafdce865e 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -30,20 +30,24 @@ logger = logging.getLogger(__name__) class PresenceStatusRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/presence/(?P[^/]*)/status") + def __init__(self, hs): + super(PresenceStatusRestServlet, self).__init__(hs) + self.presence_handler = hs.get_presence_handler() + @defer.inlineCallbacks def on_GET(self, request, user_id): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: - allowed = yield self.handlers.presence_handler.is_visible( + allowed = yield self.presence_handler.is_visible( observed_user=user, observer_user=requester.user, ) if not allowed: raise AuthError(403, "You are not allowed to see their presence.") - state = yield self.handlers.presence_handler.get_state(target_user=user) + state = yield self.presence_handler.get_state(target_user=user) defer.returnValue((200, state)) @@ -74,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except: raise SynapseError(400, "Unable to parse state") - yield self.handlers.presence_handler.set_state(user, state) + yield self.presence_handler.set_state(user, state) defer.returnValue((200, {})) @@ -85,6 +89,10 @@ class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceListRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/presence/list/(?P[^/]*)") + def __init__(self, hs): + super(PresenceListRestServlet, self).__init__(hs) + self.presence_handler = hs.get_presence_handler() + @defer.inlineCallbacks def on_GET(self, request, user_id): requester = yield self.auth.get_user_by_req(request) @@ -96,7 +104,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if requester.user != user: raise SynapseError(400, "Cannot get another user's presence list") - presence = yield self.handlers.presence_handler.get_presence_list( + presence = yield self.presence_handler.get_presence_list( observer_user=user, accepted=True ) @@ -123,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if len(u) == 0: continue invited_user = UserID.from_string(u) - yield self.handlers.presence_handler.send_presence_invite( + yield self.presence_handler.send_presence_invite( observer_user=user, observed_user=invited_user ) @@ -134,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if len(u) == 0: continue dropped_user = UserID.from_string(u) - yield self.handlers.presence_handler.drop( + yield self.presence_handler.drop( observer_user=user, observed_user=dropped_user ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index b223fb7e5f..9c89442ce6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -570,7 +570,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomTypingRestServlet, self).__init__(hs) - self.presence_handler = hs.get_handlers().presence_handler + self.presence_handler = hs.get_presence_handler() @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index b831d8c95e..891cef99c6 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -37,7 +37,7 @@ class ReceiptRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_handlers().receipts_handler - self.presence_handler = hs.get_handlers().presence_handler + self.presence_handler = hs.get_presence_handler() @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 60d3dc4030..812abe22b1 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -83,7 +83,7 @@ class SyncRestServlet(RestServlet): self.sync_handler = hs.get_handlers().sync_handler self.clock = hs.get_clock() self.filtering = hs.get_filtering() - self.presence_handler = hs.get_handlers().presence_handler + self.presence_handler = hs.get_presence_handler() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/server.py b/synapse/server.py index ee138de756..6d01b68bd4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -27,6 +27,7 @@ from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFa from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers +from synapse.handlers.presence import PresenceHandler from synapse.state import StateHandler from synapse.storage import DataStore from synapse.util import Clock @@ -78,6 +79,7 @@ class HomeServer(object): 'auth', 'rest_servlet_factory', 'state_handler', + 'presence_handler', 'notifier', 'distributor', 'client_resource', @@ -164,6 +166,9 @@ class HomeServer(object): def build_state_handler(self): return StateHandler(self) + def build_presence_handler(self): + return PresenceHandler(self) + def build_event_sources(self): return EventSources(self) diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index b1dd7b4a74..1258aaacb1 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -78,7 +78,7 @@ class ReplicationResourceCase(unittest.TestCase): @defer.inlineCallbacks def test_presence(self): get = self.get(presence="-1") - yield self.hs.get_handlers().presence_handler.set_state( + yield self.hs.get_presence_handler().set_state( self.user, {"presence": "online"} ) code, body = yield get -- cgit 1.5.1 From 0cb441fedd77b42f307745a441b804fee6386cb5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 17 May 2016 15:58:46 +0100 Subject: Move typing handler out of the Handlers object --- synapse/handlers/__init__.py | 2 -- synapse/handlers/typing.py | 33 +++++++++++++++------------------ synapse/replication/resource.py | 2 +- synapse/rest/client/v1/room.py | 7 +++---- synapse/server.py | 5 +++++ tests/handlers/test_typing.py | 10 +--------- tests/replication/test_resource.py | 2 +- tests/rest/client/v1/test_typing.py | 2 +- 8 files changed, 27 insertions(+), 36 deletions(-) (limited to 'tests/replication') diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index e1fc9a58ad..9442ae6f1d 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -25,7 +25,6 @@ from .events import EventStreamHandler, EventHandler from .federation import FederationHandler from .profile import ProfileHandler from .directory import DirectoryHandler -from .typing import TypingNotificationHandler from .admin import AdminHandler from .appservice import ApplicationServicesHandler from .auth import AuthHandler @@ -53,7 +52,6 @@ class Handlers(object): self.profile_handler = ProfileHandler(hs) self.room_list_handler = RoomListHandler(hs) self.directory_handler = DirectoryHandler(hs) - self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) self.receipts_handler = ReceiptsHandler(hs) asapi = ApplicationServiceApi(hs) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index fca8d25c3f..d46f05f426 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from ._base import BaseHandler - from synapse.api.errors import SynapseError, AuthError from synapse.util.logcontext import PreserveLoggingContext from synapse.util.metrics import Measure @@ -35,12 +33,13 @@ logger = logging.getLogger(__name__) RoomMember = namedtuple("RoomMember", ("room_id", "user")) -class TypingNotificationHandler(BaseHandler): +class TypingHandler(object): def __init__(self, hs): - super(TypingNotificationHandler, self).__init__(hs) - self.store = hs.get_datastore() self.server_name = hs.config.server_name + self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.notifier = hs.get_notifier() self.clock = hs.get_clock() @@ -68,7 +67,7 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") if target_user != auth_user: @@ -111,7 +110,7 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") if target_user != auth_user: @@ -133,7 +132,7 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def user_left_room(self, user, room_id): - if self.hs.is_mine(user): + if self.is_mine(user): member = RoomMember(room_id=room_id, user=user) yield self._stopped_typing(member) @@ -228,16 +227,14 @@ class TypingNotificationEventSource(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() - self._handler = None - - def handler(self): - # Avoid cyclic dependency in handler setup - if not self._handler: - self._handler = self.hs.get_handlers().typing_notification_handler - return self._handler + # We can't call get_typing_handler here because there's a cycle: + # + # Typing -> Notifier -> TypingNotificationEventSource -> Typing + # + self.get_typing_handler = hs.get_typing_handler def _make_event_for(self, room_id): - typing = self.handler()._room_typing[room_id] + typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", "room_id": room_id, @@ -249,7 +246,7 @@ class TypingNotificationEventSource(object): def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) - handler = self.handler() + handler = self.get_typing_handler() events = [] for room_id in room_ids: @@ -263,7 +260,7 @@ class TypingNotificationEventSource(object): return events, handler._latest_room_serial def get_current_key(self): - return self.handler()._latest_room_serial + return self.get_typing_handler()._latest_room_serial def get_pagination_rows(self, user, pagination_config, key): return ([], pagination_config.from_key) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index b0e7a17670..847f212a3d 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -110,7 +110,7 @@ class ReplicationResource(Resource): self.store = hs.get_datastore() self.sources = hs.get_event_sources() self.presence_handler = hs.get_presence_handler() - self.typing_handler = hs.get_handlers().typing_notification_handler + self.typing_handler = hs.get_typing_handler() self.notifier = hs.notifier self.clock = hs.get_clock() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9c89442ce6..cf478c6f79 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -571,6 +571,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): def __init__(self, hs): super(RoomTypingRestServlet, self).__init__(hs) self.presence_handler = hs.get_presence_handler() + self.typing_handler = hs.get_typing_handler() @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): @@ -581,19 +582,17 @@ class RoomTypingRestServlet(ClientV1RestServlet): content = parse_json_object_from_request(request) - typing_handler = self.handlers.typing_notification_handler - yield self.presence_handler.bump_presence_active_time(requester.user) if content["typing"]: - yield typing_handler.started_typing( + yield self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, timeout=content.get("timeout", 30000), ) else: - yield typing_handler.stopped_typing( + yield self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, diff --git a/synapse/server.py b/synapse/server.py index 785a087452..01f828819f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -29,6 +29,7 @@ from synapse.api.auth import Auth from synapse.handlers import Handlers from synapse.handlers.presence import PresenceHandler from synapse.handlers.sync import SyncHandler +from synapse.handlers.typing import TypingHandler from synapse.state import StateHandler from synapse.storage import DataStore from synapse.util import Clock @@ -82,6 +83,7 @@ class HomeServer(object): 'state_handler', 'presence_handler', 'sync_handler', + 'typing_handler', 'notifier', 'distributor', 'client_resource', @@ -171,6 +173,9 @@ class HomeServer(object): def build_presence_handler(self): return PresenceHandler(self) + def build_typing_handler(self): + return TypingHandler(self) + def build_sync_handler(self): return SyncHandler(self) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index d38ca37d63..abb739ae52 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -25,8 +25,6 @@ from ..utils import ( ) from synapse.api.errors import AuthError -from synapse.handlers.typing import TypingNotificationHandler - from synapse.types import UserID @@ -49,11 +47,6 @@ def _make_edu_json(origin, edu_type, content): return json.dumps(_expect_edu("test", edu_type, content, origin=origin)) -class JustTypingNotificationHandlers(object): - def __init__(self, hs): - self.typing_notification_handler = TypingNotificationHandler(hs) - - class TypingNotificationsTestCase(unittest.TestCase): """Tests typing notifications to rooms.""" @defer.inlineCallbacks @@ -89,9 +82,8 @@ class TypingNotificationsTestCase(unittest.TestCase): http_client=self.mock_http_client, keyring=Mock(), ) - hs.handlers = JustTypingNotificationHandlers(hs) - self.handler = hs.get_handlers().typing_notification_handler + self.handler = hs.get_typing_handler() self.event_source = hs.get_event_sources().sources["typing"] diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 1258aaacb1..842e3d29d7 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -93,7 +93,7 @@ class ReplicationResourceCase(unittest.TestCase): def test_typing(self): room_id = yield self.create_room() get = self.get(typing="-1") - yield self.hs.get_handlers().typing_notification_handler.started_typing( + yield self.hs.get_typing_handler().started_typing( self.user, self.user, room_id, timeout=2 ) code, body = yield get diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index d0037a53ef..467f253ef6 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -106,7 +106,7 @@ class RoomTypingTestCase(RestTestCase): yield self.join(self.room_id, user="@jim:red") def tearDown(self): - self.hs.get_handlers().typing_notification_handler.tearDown() + self.hs.get_typing_handler().tearDown() @defer.inlineCallbacks def test_set_typing(self): -- cgit 1.5.1