diff --git a/changelog.d/4942.bugfix b/changelog.d/4942.bugfix
new file mode 100644
index 0000000000..590d80d58f
--- /dev/null
+++ b/changelog.d/4942.bugfix
@@ -0,0 +1 @@
+Fix bug where presence updates were sent to all servers in a room when a new server joined, rather than to just the new server.
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 04d04a4457..0240b339b0 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -55,7 +55,12 @@ class FederationRemoteSendQueue(object):
self.is_mine_id = hs.is_mine_id
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = SortedDict() # Stream position -> user_id
+ self.presence_changed = SortedDict() # Stream position -> list[user_id]
+
+ # Stores the destinations we need to explicitly send presence to about a
+ # given user.
+ # Stream position -> (user_id, destinations)
+ self.presence_destinations = SortedDict()
self.keyed_edu = {} # (destination, key) -> EDU
self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
@@ -77,7 +82,7 @@ class FederationRemoteSendQueue(object):
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
- "edus", "device_messages", "pos_time",
+ "edus", "device_messages", "pos_time", "presence_destinations",
]:
register(queue_name, getattr(self, queue_name))
@@ -121,6 +126,15 @@ class FederationRemoteSendQueue(object):
for user_id in uids
)
+ keys = self.presence_destinations.keys()
+ i = self.presence_destinations.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.presence_destinations[key]
+
+ user_ids.update(
+ user_id for user_id, _ in self.presence_destinations.values()
+ )
+
to_del = [
user_id for user_id in self.presence_map if user_id not in user_ids
]
@@ -209,6 +223,20 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data()
+ def send_presence_to_destinations(self, states, destinations):
+ """As per FederationSender
+
+ Args:
+ states (list[UserPresenceState])
+ destinations (list[str])
+ """
+ for state in states:
+ pos = self._next_pos()
+ self.presence_map.update({state.user_id: state for state in states})
+ self.presence_destinations[pos] = (state.user_id, destinations)
+
+ self.notifier.on_new_replication_data()
+
def send_device_messages(self, destination):
"""As per FederationSender"""
pos = self._next_pos()
@@ -261,6 +289,16 @@ class FederationRemoteSendQueue(object):
state=self.presence_map[user_id],
)))
+ # Fetch presence to send to destinations
+ i = self.presence_destinations.bisect_right(from_token)
+ j = self.presence_destinations.bisect_right(to_token) + 1
+
+ for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
+ rows.append((pos, PresenceDestinationsRow(
+ state=self.presence_map[user_id],
+ destinations=list(dests),
+ )))
+
# Fetch changes keyed edus
i = self.keyed_edu_changed.bisect_right(from_token)
j = self.keyed_edu_changed.bisect_right(to_token) + 1
@@ -357,6 +395,29 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
buff.presence.append(self.state)
+class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
+ "state", # UserPresenceState
+ "destinations", # list[str]
+))):
+ TypeId = "pd"
+
+ @staticmethod
+ def from_data(data):
+ return PresenceDestinationsRow(
+ state=UserPresenceState.from_dict(data["state"]),
+ destinations=data["dests"],
+ )
+
+ def to_data(self):
+ return {
+ "state": self.state.as_dict(),
+ "dests": self.destinations,
+ }
+
+ def add_to_buffer(self, buff):
+ buff.presence_destinations.append((self.state, self.destinations))
+
+
class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
"key", # tuple(str) - the edu key passed to send_edu
"edu", # Edu
@@ -428,6 +489,7 @@ TypeToRow = {
Row.TypeId: Row
for Row in (
PresenceRow,
+ PresenceDestinationsRow,
KeyedEduRow,
EduRow,
DeviceRow,
@@ -437,6 +499,7 @@ TypeToRow = {
ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState)
+ "presence_destinations", # list of tuples of UserPresenceState and destinations
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
"device_destinations", # set of destinations
@@ -458,6 +521,7 @@ def process_rows_for_federation(transaction_queue, rows):
buff = ParsedFederationStreamData(
presence=[],
+ presence_destinations=[],
keyed_edus={},
edus={},
device_destinations=set(),
@@ -476,6 +540,11 @@ def process_rows_for_federation(transaction_queue, rows):
if buff.presence:
transaction_queue.send_presence(buff.presence)
+ for state, destinations in buff.presence_destinations:
+ transaction_queue.send_presence_to_destinations(
+ states=[state], destinations=destinations,
+ )
+
for destination, edu_map in iteritems(buff.keyed_edus):
for key, edu in edu_map.items():
transaction_queue.send_edu(edu, key)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 1dc041752b..4f0f939102 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -371,7 +371,7 @@ class FederationSender(object):
return
# First we queue up the new presence by user ID, so multiple presence
- # updates in quick successtion are correctly handled
+ # updates in quick succession are correctly handled.
# We only want to send presence for our own users, so lets always just
# filter here just in case.
self.pending_presence.update({
@@ -402,6 +402,23 @@ class FederationSender(object):
finally:
self._processing_pending_presence = False
+ def send_presence_to_destinations(self, states, destinations):
+ """Send the given presence states to the given destinations.
+
+ Args:
+ states (list[UserPresenceState])
+ destinations (list[str])
+ """
+
+ if not states or not self.hs.config.use_presence:
+ # No-op if presence is disabled.
+ return
+
+ for destination in destinations:
+ if destination == self.server_name:
+ continue
+ self._get_per_destination_queue(destination).send_presence(states)
+
@measure_func("txnqueue._process_presence")
@defer.inlineCallbacks
def _process_presence_inner(self, states):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 37e87fc054..e85c49742d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -31,9 +31,11 @@ from prometheus_client import Counter
from twisted.internet import defer
-from synapse.api.constants import PresenceState
+import synapse.metrics
+from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
@@ -98,6 +100,7 @@ class PresenceHandler(object):
self.hs = hs
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
+ self.server_name = hs.hostname
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
@@ -132,9 +135,6 @@ class PresenceHandler(object):
)
)
- distributor = hs.get_distributor()
- distributor.observe("user_joined_room", self.user_joined_room)
-
active_presence = self.store.take_presence_startup_info()
# A dictionary of the current state of users. This is prefilled with
@@ -220,6 +220,15 @@ class PresenceHandler(object):
LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
lambda: len(self.wheel_timer))
+ # Used to handle sending of presence to newly joined users/servers
+ if hs.config.use_presence:
+ self.notifier.add_replication_callback(self.notify_new_event)
+
+ # Presence is best effort and quickly heals itself, so lets just always
+ # stream from the current state when we restart.
+ self._event_pos = self.store.get_current_events_token()
+ self._event_processing = False
+
@defer.inlineCallbacks
def _on_shutdown(self):
"""Gets called when shutting down. This lets us persist any updates that
@@ -751,31 +760,6 @@ class PresenceHandler(object):
yield self._update_states([prev_state.copy_and_replace(**new_fields)])
@defer.inlineCallbacks
- def user_joined_room(self, user, room_id):
- """Called (via the distributor) when a user joins a room. This funciton
- sends presence updates to servers, either:
- 1. the joining user is a local user and we send their presence to
- all servers in the room.
- 2. the joining user is a remote user and so we send presence for all
- local users in the room.
- """
- # We only need to send presence to servers that don't have it yet. We
- # don't need to send to local clients here, as that is done as part
- # of the event stream/sync.
- # TODO: Only send to servers not already in the room.
- if self.is_mine(user):
- state = yield self.current_state_for_user(user.to_string())
-
- self._push_to_remotes([state])
- else:
- user_ids = yield self.store.get_users_in_room(room_id)
- user_ids = list(filter(self.is_mine_id, user_ids))
-
- states = yield self.current_state_for_users(user_ids)
-
- self._push_to_remotes(list(states.values()))
-
- @defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
"""Returns the presence for all users in their presence list.
"""
@@ -945,6 +929,140 @@ class PresenceHandler(object):
rows = yield self.store.get_all_presence_updates(last_id, current_id)
defer.returnValue(rows)
+ def notify_new_event(self):
+ """Called when new events have happened. Handles users and servers
+ joining rooms and require being sent presence.
+ """
+
+ if self._event_processing:
+ return
+
+ @defer.inlineCallbacks
+ def _process_presence():
+ assert not self._event_processing
+
+ self._event_processing = True
+ try:
+ yield self._unsafe_process()
+ finally:
+ self._event_processing = False
+
+ run_as_background_process("presence.notify_new_event", _process_presence)
+
+ @defer.inlineCallbacks
+ def _unsafe_process(self):
+ # Loop round handling deltas until we're up to date
+ while True:
+ with Measure(self.clock, "presence_delta"):
+ deltas = yield self.store.get_current_state_deltas(self._event_pos)
+ if not deltas:
+ return
+
+ yield self._handle_state_delta(deltas)
+
+ self._event_pos = deltas[-1]["stream_id"]
+
+ # Expose current event processing position to prometheus
+ synapse.metrics.event_processing_positions.labels("presence").set(
+ self._event_pos
+ )
+
+ @defer.inlineCallbacks
+ def _handle_state_delta(self, deltas):
+ """Process current state deltas to find new joins that need to be
+ handled.
+ """
+ for delta in deltas:
+ typ = delta["type"]
+ state_key = delta["state_key"]
+ room_id = delta["room_id"]
+ event_id = delta["event_id"]
+ prev_event_id = delta["prev_event_id"]
+
+ logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+
+ if typ != EventTypes.Member:
+ continue
+
+ event = yield self.store.get_event(event_id)
+ if event.content.get("membership") != Membership.JOIN:
+ # We only care about joins
+ continue
+
+ if prev_event_id:
+ prev_event = yield self.store.get_event(prev_event_id)
+ if prev_event.content.get("membership") == Membership.JOIN:
+ # Ignore changes to join events.
+ continue
+
+ yield self._on_user_joined_room(room_id, state_key)
+
+ @defer.inlineCallbacks
+ def _on_user_joined_room(self, room_id, user_id):
+ """Called when we detect a user joining the room via the current state
+ delta stream.
+
+ Args:
+ room_id (str)
+ user_id (str)
+
+ Returns:
+ Deferred
+ """
+
+ if self.is_mine_id(user_id):
+ # If this is a local user then we need to send their presence
+ # out to hosts in the room (who don't already have it)
+
+ # TODO: We should be able to filter the hosts down to those that
+ # haven't previously seen the user
+
+ state = yield self.current_state_for_user(user_id)
+ hosts = yield self.state.get_current_hosts_in_room(room_id)
+
+ # Filter out ourselves.
+ hosts = set(host for host in hosts if host != self.server_name)
+
+ self.federation.send_presence_to_destinations(
+ states=[state],
+ destinations=hosts,
+ )
+ else:
+ # A remote user has joined the room, so we need to:
+ # 1. Check if this is a new server in the room
+ # 2. If so send any presence they don't already have for
+ # local users in the room.
+
+ # TODO: We should be able to filter the users down to those that
+ # the server hasn't previously seen
+
+ # TODO: Check that this is actually a new server joining the
+ # room.
+
+ user_ids = yield self.state.get_current_user_in_room(room_id)
+ user_ids = list(filter(self.is_mine_id, user_ids))
+
+ states = yield self.current_state_for_users(user_ids)
+
+ # Filter out old presence, i.e. offline presence states where
+ # the user hasn't been active for a week. We can change this
+ # depending on what we want the UX to be, but at the least we
+ # should filter out offline presence where the state is just the
+ # default state.
+ now = self.clock.time_msec()
+ states = [
+ state for state in states.values()
+ if state.state != PresenceState.OFFLINE
+ or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
+ or state.status_msg is not None
+ ]
+
+ if states:
+ self.federation.send_presence_to_destinations(
+ states=states,
+ destinations=[get_domain_from_id(user_id)],
+ )
+
def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties.
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index fc2b646ba2..94c6080e34 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -16,7 +16,11 @@
from mock import Mock, call
-from synapse.api.constants import PresenceState
+from signedjson.key import generate_signing_key
+
+from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.events import room_version_to_event_format
+from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
FEDERATION_PING_INTERVAL,
FEDERATION_TIMEOUT,
@@ -26,7 +30,9 @@ from synapse.handlers.presence import (
handle_timeout,
handle_update,
)
+from synapse.rest.client.v1 import room
from synapse.storage.presence import UserPresenceState
+from synapse.types import UserID, get_domain_from_id
from tests import unittest
@@ -405,3 +411,171 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(state, new_state)
+
+
+class PresenceJoinTestCase(unittest.HomeserverTestCase):
+ """Tests remote servers get told about presence of users in the room when
+ they join and when new local users join.
+ """
+
+ user_id = "@test:server"
+
+ servlets = [room.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver(
+ "server", http_client=None,
+ federation_sender=Mock(),
+ )
+ return hs
+
+ def prepare(self, reactor, clock, hs):
+ self.federation_sender = hs.get_federation_sender()
+ self.event_builder_factory = hs.get_event_builder_factory()
+ self.federation_handler = hs.get_handlers().federation_handler
+ self.presence_handler = hs.get_presence_handler()
+
+ # self.event_builder_for_2 = EventBuilderFactory(hs)
+ # self.event_builder_for_2.hostname = "test2"
+
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.auth = hs.get_auth()
+
+ # We don't actually check signatures in tests, so lets just create a
+ # random key to use.
+ self.random_signing_key = generate_signing_key("ver")
+
+ def test_remote_joins(self):
+ # We advance time to something that isn't 0, as we use 0 as a special
+ # value.
+ self.reactor.advance(1000000000000)
+
+ # Create a room with two local users
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.join(room_id, "@test2:server")
+
+ # Mark test2 as online, test will be offline with a last_active of 0
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
+ )
+ self.reactor.pump([0]) # Wait for presence updates to be handled
+
+ #
+ # Test that a new server gets told about existing presence
+ #
+
+ self.federation_sender.reset_mock()
+
+ # Add a new remote server to the room
+ self._add_new_user(room_id, "@alice:server2")
+
+ # We shouldn't have sent out any local presence *updates*
+ self.federation_sender.send_presence.assert_not_called()
+
+ # When new server is joined we send it the local users presence states.
+ # We expect to only see user @test2:server, as @test:server is offline
+ # and has a zero last_active_ts
+ expected_state = self.get_success(
+ self.presence_handler.current_state_for_user("@test2:server")
+ )
+ self.assertEqual(expected_state.state, PresenceState.ONLINE)
+ self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+ destinations=["server2"], states=[expected_state]
+ )
+
+ #
+ # Test that only the new server gets sent presence and not existing servers
+ #
+
+ self.federation_sender.reset_mock()
+ self._add_new_user(room_id, "@bob:server3")
+
+ self.federation_sender.send_presence.assert_not_called()
+ self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+ destinations=["server3"], states=[expected_state]
+ )
+
+ def test_remote_gets_presence_when_local_user_joins(self):
+ # We advance time to something that isn't 0, as we use 0 as a special
+ # value.
+ self.reactor.advance(1000000000000)
+
+ # Create a room with one local users
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # Mark test as online
+ self.presence_handler.set_state(
+ UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE},
+ )
+
+ # Mark test2 as online, test will be offline with a last_active of 0.
+ # Note we don't join them to the room yet
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
+ )
+
+ # Add servers to the room
+ self._add_new_user(room_id, "@alice:server2")
+ self._add_new_user(room_id, "@bob:server3")
+
+ self.reactor.pump([0]) # Wait for presence updates to be handled
+
+ #
+ # Test that when a local join happens remote servers get told about it
+ #
+
+ self.federation_sender.reset_mock()
+
+ # Join local user to room
+ self.helper.join(room_id, "@test2:server")
+
+ self.reactor.pump([0]) # Wait for presence updates to be handled
+
+ # We shouldn't have sent out any local presence *updates*
+ self.federation_sender.send_presence.assert_not_called()
+
+ # We expect to only send test2 presence to server2 and server3
+ expected_state = self.get_success(
+ self.presence_handler.current_state_for_user("@test2:server")
+ )
+ self.assertEqual(expected_state.state, PresenceState.ONLINE)
+ self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+ destinations=set(("server2", "server3")),
+ states=[expected_state]
+ )
+
+ def _add_new_user(self, room_id, user_id):
+ """Add new user to the room by creating an event and poking the federation API.
+ """
+
+ hostname = get_domain_from_id(user_id)
+
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ builder = EventBuilder(
+ state=self.state,
+ auth=self.auth,
+ store=self.store,
+ clock=self.clock,
+ hostname=hostname,
+ signing_key=self.random_signing_key,
+ format_version=room_version_to_event_format(room_version),
+ room_id=room_id,
+ type=EventTypes.Member,
+ sender=user_id,
+ state_key=user_id,
+ content={"membership": Membership.JOIN}
+ )
+
+ prev_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+
+ event = self.get_success(builder.build(prev_event_ids))
+
+ self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
+
+ # Check that it was successfully persisted.
+ self.get_success(self.store.get_event(event.event_id))
+ self.get_success(self.store.get_event(event.event_id))
|