diff --git a/changelog.d/16426.misc b/changelog.d/16426.misc
new file mode 100644
index 0000000000..208a007171
--- /dev/null
+++ b/changelog.d/16426.misc
@@ -0,0 +1 @@
+Refactor code adjacent to the receipts stream to simplify it and improve its type annotations.
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 7de7bd3289..c200a45f3a 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -216,7 +216,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
- stream_key: str,
+ stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
@@ -326,7 +326,7 @@ class ApplicationServicesHandler:
async def _notify_interested_services_ephemeral(
self,
services: List[ApplicationService],
- stream_key: str,
+ stream_key: StreamKeyType,
new_token: int,
users: Collection[Union[str, UserID]],
) -> None:
diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py
index 7ed88a3611..87b428ab1c 100644
--- a/synapse/handlers/push_rules.py
+++ b/synapse/handlers/push_rules.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.push_rule import RuleNotFoundException
from synapse.synapse_rust.push import get_base_rule_ids
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -114,7 +114,9 @@ class PushRulesHandler:
user_id: the user ID the change is for.
"""
stream_id = self._main_store.get_max_push_rules_stream_id()
- self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
+ self._notifier.on_new_event(
+ StreamKeyType.PUSH_RULES, stream_id, users=[user_id]
+ )
async def push_rules_for_user(
self, user: UserID
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a7a29b758b..69ac468f75 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -130,11 +130,10 @@ class ReceiptsHandler:
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier."""
- min_batch_id: Optional[int] = None
- max_batch_id: Optional[int] = None
+ receipts_persisted: List[ReadReceipt] = []
for receipt in receipts:
- res = await self.store.insert_receipt(
+ stream_id = await self.store.insert_receipt(
receipt.room_id,
receipt.receipt_type,
receipt.user_id,
@@ -143,30 +142,26 @@ class ReceiptsHandler:
receipt.data,
)
- if not res:
- # res will be None if this receipt is 'old'
+ if stream_id is None:
+ # stream_id will be None if this receipt is 'old'
continue
- stream_id, max_persisted_id = res
+ receipts_persisted.append(receipt)
- if min_batch_id is None or stream_id < min_batch_id:
- min_batch_id = stream_id
- if max_batch_id is None or max_persisted_id > max_batch_id:
- max_batch_id = max_persisted_id
-
- # Either both of these should be None or neither.
- if min_batch_id is None or max_batch_id is None:
+ if not receipts_persisted:
# no new receipts
return False
- affected_room_ids = list({r.room_id for r in receipts})
+ max_batch_id = self.store.get_max_receipt_stream_id()
+
+ affected_room_ids = list({r.room_id for r in receipts_persisted})
self.notifier.on_new_event(
StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
)
# Note that the min here shouldn't be relied upon to be accurate.
await self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
+ {r.user_id for r in receipts_persisted}
)
return True
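
To illustrate the simplified batching above, here is a minimal, self-contained sketch of the new flow: receipts whose insert returns `None` are skipped, only the persisted ones feed the room and user sets, and a single max token covers the batch. `ReadReceipt` is trimmed down and `FakeReceiptStore` is a hypothetical stand-in, not Synapse's real store.

```python
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass
class ReadReceipt:
    """Simplified stand-in for synapse.types.ReadReceipt."""
    room_id: str
    user_id: str


class FakeReceiptStore:
    """Hypothetical store: insert_receipt returns None for stale receipts."""

    def __init__(self) -> None:
        self._stream_id = 0

    def insert_receipt(self, receipt: ReadReceipt) -> Optional[int]:
        if receipt.room_id.startswith("!stale"):
            return None  # mimics an 'old' receipt that was not persisted
        self._stream_id += 1
        return self._stream_id

    def get_max_receipt_stream_id(self) -> int:
        return self._stream_id


def handle_new_receipts(store: FakeReceiptStore, receipts: List[ReadReceipt]) -> bool:
    # Collect only the receipts that were actually persisted.
    receipts_persisted: List[ReadReceipt] = []
    for receipt in receipts:
        stream_id = store.insert_receipt(receipt)
        if stream_id is None:
            continue
        receipts_persisted.append(receipt)

    if not receipts_persisted:
        return False  # no new receipts

    # One max token for the whole batch; callers no longer track a min.
    max_batch_id = store.get_max_receipt_stream_id()
    affected_rooms: Set[str] = {r.room_id for r in receipts_persisted}
    affected_users: Set[str] = {r.user_id for r in receipts_persisted}
    print(f"notify token={max_batch_id} rooms={affected_rooms} users={affected_users}")
    return True


handle_new_receipts(
    FakeReceiptStore(),
    [ReadReceipt("!room:example.org", "@alice:example.org"),
     ReadReceipt("!stale:example.org", "@bob:example.org")],
)
```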
diff --git a/synapse/notifier.py b/synapse/notifier.py
index fc39e5c963..99e7715896 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -126,7 +126,7 @@ class _NotifierUserStream:
def notify(
self,
- stream_key: str,
+ stream_key: StreamKeyType,
stream_id: Union[int, RoomStreamToken],
time_now_ms: int,
) -> None:
@@ -454,7 +454,7 @@ class Notifier:
def on_new_event(
self,
- stream_key: str,
+ stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[StrCollection] = None,
@@ -655,30 +655,29 @@ class Notifier:
events: List[Union[JsonDict, EventBase]] = []
end_token = from_token
- for name, source in self.event_sources.sources.get_sources():
- keyname = "%s_key" % name
- before_id = getattr(before_token, keyname)
- after_id = getattr(after_token, keyname)
+ for keyname, source in self.event_sources.sources.get_sources():
+ before_id = before_token.get_field(keyname)
+ after_id = after_token.get_field(keyname)
if before_id == after_id:
continue
new_events, new_key = await source.get_new_events(
user=user,
- from_key=getattr(from_token, keyname),
+ from_key=from_token.get_field(keyname),
limit=limit,
is_guest=is_peeking,
room_ids=room_ids,
explicit_room_id=explicit_room_id,
)
- if name == "room":
+ if keyname == StreamKeyType.ROOM:
new_events = await filter_events_for_client(
self._storage_controllers,
user.to_string(),
new_events,
is_peeking=is_peeking,
)
- elif name == "presence":
+ elif keyname == StreamKeyType.PRESENCE:
now = self.clock.time_msec()
new_events[:] = [
{
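
For context on the `get_field` calls above, here is a minimal sketch (not Synapse's real classes) of how an enum-keyed lookup replaces the old `getattr(token, "%s_key" % name)` string formatting:

```python
from enum import Enum


class StreamKeyType(Enum):
    ROOM = "room_key"
    PRESENCE = "presence_key"


class ToyToken:
    """Hypothetical token with one attribute per stream key."""

    def __init__(self, room_key: int, presence_key: int) -> None:
        self.room_key = room_key
        self.presence_key = presence_key

    def get_field(self, key: StreamKeyType) -> int:
        # The enum value is the attribute name, so no string formatting is needed.
        return getattr(self, key.value)


before = ToyToken(room_key=3, presence_key=7)
after = ToyToken(room_key=5, presence_key=7)

for keyname in StreamKeyType:
    if before.get_field(keyname) == after.get_field(keyname):
        continue  # nothing new on this stream
    print(f"{keyname.name} advanced to {after.get_field(keyname)}")
```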
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 9e3a98741a..9e5eb2a445 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -182,7 +182,7 @@ class Pusher(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
- def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ def on_new_receipts(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 1710dd51b9..cf45fd09a8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -99,7 +99,7 @@ class EmailPusher(Pusher):
pass
self.timed_call = None
- def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ def on_new_receipts(self) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 50027680cb..725910a659 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -160,7 +160,7 @@ class HttpPusher(Pusher):
if should_check_for_notifs:
self._start_processing()
- def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ def on_new_receipts(self) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 6517e3566f..15a2cc932f 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -292,20 +292,12 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- async def on_new_receipts(
- self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
- ) -> None:
+ async def on_new_receipts(self, users_affected: StrCollection) -> None:
if not self.pushers:
# nothing to do here.
return
try:
- # Need to subtract 1 from the minimum because the lower bound here
- # is not inclusive
- users_affected = await self.store.get_users_sent_receipts_between(
- min_stream_id - 1, max_stream_id
- )
-
for u in users_affected:
# Don't push if the user account has expired
expired = await self._account_validity_handler.is_user_expired(u)
@@ -314,7 +306,7 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
- p.on_new_receipts(min_stream_id, max_stream_id)
+ p.on_new_receipts()
except Exception:
logger.exception("Exception in pusher on_new_receipts")
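
A hedged sketch of the new calling convention: callers pass the users who sent receipts, and the pool fans out to each of that user's pushers with no stream-ID bookkeeping. `FakePusher` is an illustrative stand-in, not one of Synapse's pusher classes.

```python
from typing import Dict, Iterable, List


class FakePusher:
    """Illustrative stand-in for an EmailPusher/HttpPusher."""

    def on_new_receipts(self) -> None:
        print("pusher notified of new receipts")


def on_new_receipts(
    pushers_by_user: Dict[str, List[FakePusher]], users_affected: Iterable[str]
) -> None:
    for user_id in users_affected:
        for pusher in pushers_by_user.get(user_id, []):
            pusher.on_new_receipts()


on_new_receipts(
    {"@alice:example.org": [FakePusher()]},
    {"@alice:example.org", "@nobody:example.org"},
)
```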
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index f4f2b29e96..d5337fe588 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -129,9 +129,7 @@ class ReplicationDataHandler:
self.notifier.on_new_event(
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
)
- await self._pusher_pool.on_new_receipts(
- token, token, {row.room_id for row in rows}
- )
+ await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index d01f28cc80..bc7c6a6346 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -208,7 +208,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
- StreamKeyType.ROOM: room_key,
+ StreamKeyType.ROOM.value: room_key,
}
)
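
Since `StreamKeyType` members are now `Enum` members rather than plain string constants, a member no longer compares equal to its underlying string, so call sites that need the literal key string (like this log-context dict key) have to unwrap it with `.value`. A quick illustration of the difference:

```python
from enum import Enum


class StreamKeyType(Enum):
    ROOM = "room_key"


# The enum member is no longer the string itself...
assert StreamKeyType.ROOM != "room_key"

# ...so places that need the raw key string must unwrap it explicitly.
assert StreamKeyType.ROOM.value == "room_key"
```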
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0231f9407b..3bab1024ea 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -742,7 +742,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_ids: List[str],
thread_id: Optional[str],
data: dict,
- ) -> Optional[Tuple[int, int]]:
+ ) -> Optional[int]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@@ -804,9 +804,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
data,
)
- max_persisted_id = self._receipts_id_gen.get_current_token()
-
- return stream_id, max_persisted_id
+ return stream_id
async def _insert_graph_receipt(
self,
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index d7084d2358..609a0978a9 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Iterator, Tuple
+from typing import TYPE_CHECKING, Sequence, Tuple
import attr
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
-from synapse.types import StreamToken
+from synapse.types import StreamKeyType, StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -37,9 +37,14 @@ class _EventSourcesInner:
receipt: ReceiptEventSource
account_data: AccountDataEventSource
- def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
- for attribute in attr.fields(_EventSourcesInner):
- yield attribute.name, getattr(self, attribute.name)
+ def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]:
+ return [
+ (StreamKeyType.ROOM, self.room),
+ (StreamKeyType.PRESENCE, self.presence),
+ (StreamKeyType.TYPING, self.typing),
+ (StreamKeyType.RECEIPT, self.receipt),
+ (StreamKeyType.ACCOUNT_DATA, self.account_data),
+ ]
class EventSources:
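
The previous implementation derived key names from attrs field introspection, yielding plain strings; the explicit list trades that reflection for keys a type checker can reason about. A hedged before/after sketch with simplified, illustrative types:

```python
from enum import Enum
from typing import List, Tuple

import attr


class StreamKeyType(Enum):
    ROOM = "room_key"
    PRESENCE = "presence_key"


@attr.s(auto_attribs=True)
class Sources:
    room: str
    presence: str

    # Old style: reflection over attrs fields, yielding untyped str names.
    def get_sources_by_reflection(self) -> List[Tuple[str, str]]:
        return [(f.name, getattr(self, f.name)) for f in attr.fields(Sources)]

    # New style: an explicit list, so each key is a StreamKeyType that
    # static checkers (and callers) can match on directly.
    def get_sources(self) -> List[Tuple[StreamKeyType, str]]:
        return [
            (StreamKeyType.ROOM, self.room),
            (StreamKeyType.PRESENCE, self.presence),
        ]


s = Sources(room="room-source", presence="presence-source")
assert [name for name, _ in s.get_sources_by_reflection()] == ["room", "presence"]
assert [key for key, _ in s.get_sources()] == [StreamKeyType.ROOM, StreamKeyType.PRESENCE]
```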
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 76b0e3e694..406d5b1611 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -22,8 +22,8 @@ from typing import (
Any,
ClassVar,
Dict,
- Final,
List,
+ Literal,
Mapping,
Match,
MutableMapping,
@@ -34,6 +34,7 @@ from typing import (
Type,
TypeVar,
Union,
+ overload,
)
import attr
@@ -649,20 +650,20 @@ class RoomStreamToken:
return "s%d" % (self.stream,)
-class StreamKeyType:
+class StreamKeyType(Enum):
"""Known stream types.
A stream is a list of entities ordered by an incrementing "stream token".
"""
- ROOM: Final = "room_key"
- PRESENCE: Final = "presence_key"
- TYPING: Final = "typing_key"
- RECEIPT: Final = "receipt_key"
- ACCOUNT_DATA: Final = "account_data_key"
- PUSH_RULES: Final = "push_rules_key"
- TO_DEVICE: Final = "to_device_key"
- DEVICE_LIST: Final = "device_list_key"
+ ROOM = "room_key"
+ PRESENCE = "presence_key"
+ TYPING = "typing_key"
+ RECEIPT = "receipt_key"
+ ACCOUNT_DATA = "account_data_key"
+ PUSH_RULES = "push_rules_key"
+ TO_DEVICE = "to_device_key"
+ DEVICE_LIST = "device_list_key"
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
@@ -784,7 +785,7 @@ class StreamToken:
def room_stream_id(self) -> int:
return self.room_key.stream
- def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
+ def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
@@ -797,16 +798,44 @@ class StreamToken:
return new_token
new_token = self.copy_and_replace(key, new_value)
- new_id = int(getattr(new_token, key))
- old_id = int(getattr(self, key))
+ new_id = new_token.get_field(key)
+ old_id = self.get_field(key)
if old_id < new_id:
return new_token
else:
return self
- def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
- return attr.evolve(self, **{key: new_value})
+ def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
+ return attr.evolve(self, **{key.value: new_value})
+
+ @overload
+ def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken:
+ ...
+
+ @overload
+ def get_field(
+ self,
+ key: Literal[
+ StreamKeyType.ACCOUNT_DATA,
+ StreamKeyType.DEVICE_LIST,
+ StreamKeyType.PRESENCE,
+ StreamKeyType.PUSH_RULES,
+ StreamKeyType.RECEIPT,
+ StreamKeyType.TO_DEVICE,
+ StreamKeyType.TYPING,
+ StreamKeyType.UN_PARTIAL_STATED_ROOMS,
+ ],
+ ) -> int:
+ ...
+
+ @overload
+ def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+ ...
+
+ def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+ """Returns the stream ID for the given key."""
+ return getattr(self, key.value)
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
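
To show what the `@overload` stubs above buy, here is a minimal, self-contained sketch (names simplified, not Synapse's real types): a type checker sees `get_field(Key.ROOM)` as returning `RoomToken`, while the other keys narrow to `int`, so callers need no casts.

```python
from enum import Enum
from typing import Literal, Union, overload


class Key(Enum):
    ROOM = "room_key"
    RECEIPT = "receipt_key"


class RoomToken:
    def __init__(self, stream: int) -> None:
        self.stream = stream


class Token:
    def __init__(self, room_key: RoomToken, receipt_key: int) -> None:
        self.room_key = room_key
        self.receipt_key = receipt_key

    @overload
    def get_field(self, key: Literal[Key.ROOM]) -> RoomToken:
        ...

    @overload
    def get_field(self, key: Literal[Key.RECEIPT]) -> int:
        ...

    @overload
    def get_field(self, key: Key) -> Union[int, RoomToken]:
        ...

    def get_field(self, key: Key) -> Union[int, RoomToken]:
        # The enum's value doubles as the attribute name on the token.
        return getattr(self, key.value)


token = Token(room_key=RoomToken(5), receipt_key=9)
room = token.get_field(Key.ROOM)      # a checker infers RoomToken here
count = token.get_field(Key.RECEIPT)  # ...and int here
assert room.stream == 5 and count == 9
```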
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7e6cdd66a..8ce6ccf529 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -31,7 +31,7 @@ from synapse.appservice import (
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.handler.notify_interested_services_ephemeral(
- "receipt_key", 580, ["@fakerecipient:example.com"]
+ StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
)
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[event]
@@ -332,7 +332,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.handler.notify_interested_services_ephemeral(
- "receipt_key", 580, ["@fakerecipient:example.com"]
+ StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
)
# This method will be called, but with an empty list of events
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -634,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
services=[interested_appservice],
- stream_key="receipt_key",
+ stream_key=StreamKeyType.RECEIPT,
new_token=stream_token,
users=[self.exclusive_as_user],
)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 95106ec8f3..3060bc9744 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -28,7 +28,7 @@ from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.typing import TypingWriterHandler
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.server import HomeServer
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -203,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+ )
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
@@ -273,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
- self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+ )
self.assertEqual(self.event_source.get_current_key(), 1)
events = self.get_success(
@@ -349,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+ )
self.mock_federation_client.put_json.assert_called_once_with(
"farm",
@@ -399,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])]
+ )
self.on_new_event.reset_mock()
self.assertEqual(self.event_source.get_current_key(), 1)
@@ -425,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.reactor.pump([16])
- self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])]
+ )
self.assertEqual(self.event_source.get_current_key(), 2)
events = self.get_success(
@@ -459,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls(
+ [call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])]
+ )
self.on_new_event.reset_mock()
self.assertEqual(self.event_source.get_current_key(), 3)
|