diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index b531ba8540..d9e8f634ae 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase):
),
], any_order=True)
+ def test_online_to_online_last_active_noop(self):
+ wheel_timer = Mock()
+ user_id = "@foo:bar"
+ now = 5000000
+
+ prev_state = UserPresenceState.default(user_id)
+ prev_state = prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
+ currently_active=True,
+ )
+
+ new_state = prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=now,
+ )
+
+ state, persist_and_notify, federation_ping = handle_update(
+ prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+ )
+
+ self.assertFalse(persist_and_notify)
+ self.assertTrue(federation_ping)
+ self.assertTrue(state.currently_active)
+ self.assertEquals(new_state.state, state.state)
+ self.assertEquals(new_state.status_msg, state.status_msg)
+ self.assertEquals(state.last_federation_update_ts, now)
+
+ self.assertEquals(wheel_timer.insert.call_count, 3)
+ wheel_timer.insert.assert_has_calls([
+ call(
+ now=now,
+ obj=user_id,
+ then=new_state.last_active_ts + IDLE_TIMER
+ ),
+ call(
+ now=now,
+ obj=user_id,
+ then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
+ ),
+ call(
+ now=now,
+ obj=user_id,
+ then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
+ ),
+ ], any_order=True)
+
def test_online_to_online_last_active(self):
wheel_timer = Mock()
user_id = "@foo:bar"
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index ab9899b7d5..b2957eef9f 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -62,6 +62,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[])
+ self.state_handler = Mock()
hs = yield setup_test_homeserver(
"test",
@@ -75,6 +76,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"set_received_txn_response",
"get_destination_retry_timings",
]),
+ state_handler=self.state_handler,
handlers=None,
notifier=mock_notifier,
resource_for_client=Mock(),
@@ -113,6 +115,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
return set(member.domain for member in self.room_members)
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
+ def get_current_user_in_room(room_id):
+ return set(str(u) for u in self.room_members)
+ self.state_handler.get_current_user_in_room = get_current_user_in_room
+
self.auth.check_joined_room = check_joined_room
# Some local users to test with
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index f33e6f60fb..44e859b5d1 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -305,7 +305,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.event_id += 1
- context = EventContext(current_state=state)
+ if state is not None:
+ state_ids = {
+ key: e.event_id for key, e in state.items()
+ }
+ else:
+ state_ids = None
+
+ context = EventContext()
+ context.current_state_ids = state_ids
+ context.prev_state_ids = state_ids
context.push_actions = push_actions
ordering = None
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py
index e70ac6f14d..b69832cc1b 100644
--- a/tests/replication/test_resource.py
+++ b/tests/replication/test_resource.py
@@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body, {})
@defer.inlineCallbacks
- def test_events_and_state(self):
- get = self.get(events="-1", state="-1", timeout="0")
+ def test_events(self):
+ get = self.get(events="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room(
synapse.types.create_requester(self.user), {}
)
@@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body["events"]["field_names"], [
"position", "internal", "json", "state_group"
])
- self.assertEquals(body["state_groups"]["field_names"], [
- "position", "room_id", "event_id"
- ])
- self.assertEquals(body["state_group_state"]["field_names"], [
- "position", "type", "state_key", "event_id"
- ])
@defer.inlineCallbacks
def test_presence(self):
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 8853cbb5fc..4fe99ebc0b 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_topo_token_is_accepted(self):
- token = "t1-0_0_0_0_0_0"
+ token = "t1-0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self):
- token = "s0_0_0_0_0_0"
+ token = "s0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 27b2b3d123..1be7d932f6 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -78,44 +78,3 @@ class RoomMemberStoreTestCase(unittest.TestCase):
)
)]
)
-
- @defer.inlineCallbacks
- def test_room_hosts(self):
- yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
-
- self.assertEquals(
- {"test"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
-
- # Should still have just one host after second join from it
- yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
-
- self.assertEquals(
- {"test"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
-
- # Should now have two hosts after join from other host
- yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN)
-
- self.assertEquals(
- {"test", "elsewhere"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
-
- # Should still have both hosts
- yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE)
-
- self.assertEquals(
- {"test", "elsewhere"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
-
- # Should have only one host after other leaves
- yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE)
-
- self.assertEquals(
- {"test"},
- (yield self.store.get_joined_hosts_for_room(self.room.to_string()))
- )
diff --git a/tests/test_state.py b/tests/test_state.py
index 1a11bbcee0..6454f994e3 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -67,9 +67,11 @@ class StateGroupStore(object):
self._event_to_state_group = {}
self._group_to_state = {}
+ self._event_id_to_event = {}
+
self._next_group = 1
- def get_state_groups(self, room_id, event_ids):
+ def get_state_groups_ids(self, room_id, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
@@ -79,22 +81,23 @@ class StateGroupStore(object):
return defer.succeed(groups)
def store_state_groups(self, event, context):
- if context.current_state is None:
+ if context.current_state_ids is None:
return
- state_events = context.current_state
-
- if event.is_state():
- state_events[(event.type, event.state_key)] = event
+ state_events = dict(context.current_state_ids)
- state_group = context.state_group
- if not state_group:
- state_group = self._next_group
- self._next_group += 1
+ self._group_to_state[context.state_group] = state_events
+ self._event_to_state_group[event.event_id] = context.state_group
- self._group_to_state[state_group] = state_events.values()
+ def get_events(self, event_ids, **kwargs):
+ return {
+ e_id: self._event_id_to_event[e_id] for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
- self._event_to_state_group[event.event_id] = state_group
+ def register_events(self, events):
+ for e in events:
+ self._event_id_to_event[e.event_id] = e
class DictObj(dict):
@@ -136,8 +139,10 @@ class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = Mock(
spec_set=[
- "get_state_groups",
+ "get_state_groups_ids",
"add_event_hashes",
+ "get_events",
+ "get_next_state_group",
]
)
hs = Mock(spec_set=[
@@ -148,6 +153,8 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
+ self.store.get_next_state_group.side_effect = Mock
+
self.state = StateHandler(hs)
self.event_id = 0
@@ -187,7 +194,7 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
context_store = {}
@@ -196,7 +203,7 @@ class StateTestCase(unittest.TestCase):
store.store_state_groups(event, context)
context_store[event.event_id] = context
- self.assertEqual(2, len(context_store["D"].current_state))
+ self.assertEqual(2, len(context_store["D"].prev_state_ids))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
@@ -239,7 +246,9 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -250,7 +259,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"START", "A", "C"},
- {e.event_id for e in context_store["D"].current_state.values()}
+ {e_id for e_id in context_store["D"].prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -303,7 +312,9 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -314,7 +325,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"START", "A", "B", "C"},
- {e.event_id for e in context_store["E"].current_state.values()}
+ {e for e in context_store["E"].prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -384,7 +395,9 @@ class StateTestCase(unittest.TestCase):
graph = Graph(nodes, edges)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -395,7 +408,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
- {e.event_id for e in context_store["D"].current_state.values()}
+ {e for e in context_store["D"].prev_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@@ -424,16 +437,11 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
- set(old_state), set(context.current_state.values())
+ set(e.event_id for e in old_state), set(context.current_state_ids.values())
)
- self.assertIsNone(context.state_group)
+ self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
@@ -449,18 +457,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
- set(old_state),
- set(context.current_state.values())
+ set(e.event_id for e in old_state), set(context.prev_state_ids.values())
)
- self.assertIsNone(context.state_group)
-
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
event = create_event(type="test_message", name="event")
@@ -473,20 +473,15 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1"
- self.store.get_state_groups.return_value = {
- group_name: old_state,
+ self.store.get_state_groups_ids.return_value = {
+ group_name: {(e.type, e.state_key): e.event_id for e in old_state},
}
context = yield self.state.compute_event_context(event)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
set([e.event_id for e in old_state]),
- set([e.event_id for e in context.current_state.values()])
+ set(context.current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@@ -503,23 +498,18 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1"
- self.store.get_state_groups.return_value = {
- group_name: old_state,
+ self.store.get_state_groups_ids.return_value = {
+ group_name: {(e.type, e.state_key): e.event_id for e in old_state},
}
context = yield self.state.compute_event_context(event)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
set([e.event_id for e in old_state]),
- set([e.event_id for e in context.current_state.values()])
+ set(context.prev_state_ids.values())
)
- self.assertIsNone(context.state_group)
+ self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
@@ -543,11 +533,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(len(context.current_state), 6)
+ self.assertEqual(len(context.current_state_ids), 6)
- self.assertIsNone(context.state_group)
+ self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
@@ -571,11 +566,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(len(context.current_state), 6)
+ self.assertEqual(len(context.current_state_ids), 6)
- self.assertIsNone(context.state_group)
+ self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_standard_depth_conflict(self):
@@ -606,9 +606,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=2),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
+ self.assertEqual(
+ old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+ )
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
@@ -625,17 +632,22 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=1),
]
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
+ self.assertEqual(
+ old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+ )
def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1"
group_name_2 = "group_name_2"
- self.store.get_state_groups.return_value = {
- group_name_1: old_state_1,
- group_name_2: old_state_2,
+ self.store.get_state_groups_ids.return_value = {
+ group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
+ group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
}
return self.state.compute_event_context(event)
|