diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
new file mode 100644
index 0000000000..f2d436ddc3
--- /dev/null
+++ b/synapse/types/__init__.py
@@ -0,0 +1,928 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import re
+import string
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ ClassVar,
+ Dict,
+ List,
+ Mapping,
+ Match,
+ MutableMapping,
+ NoReturn,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+import attr
+from frozendict import frozendict
+from signedjson.key import decode_verify_key_bytes
+from signedjson.types import VerifyKey
+from typing_extensions import Final, TypedDict
+from unpaddedbase64 import decode_base64
+from zope.interface import Interface
+
+from twisted.internet.defer import CancelledError
+from twisted.internet.interfaces import (
+ IReactorCore,
+ IReactorPluggableNameResolver,
+ IReactorSSL,
+ IReactorTCP,
+ IReactorThreads,
+ IReactorTime,
+)
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.util.cancellation import cancellable
+from synapse.util.stringutils import parse_and_validate_server_name
+
+if TYPE_CHECKING:
+ from synapse.appservice.api import ApplicationService
+ from synapse.storage.databases.main import DataStore, PurgeEventsStore
+ from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
+
+# Define a state map type from type/state_key to T (usually an event ID or
+# event)
+T = TypeVar("T")
+StateKey = Tuple[str, str]
+StateMap = Mapping[StateKey, T]
+MutableStateMap = MutableMapping[StateKey, T]
+
+# JSON types. These could be made stronger, but will do for now.
+# A JSON-serialisable dict.
+JsonDict = Dict[str, Any]
+# A JSON-serialisable mapping; roughly speaking an immutable JsonDict.
+# Useful when you have a TypedDict which isn't going to be mutated and you don't want
+# to cast to JsonDict everywhere.
+JsonMapping = Mapping[str, Any]
+# A JSON-serialisable object.
+JsonSerializable = object
+
+
+# Note that this seems to require inheriting *directly* from Interface in order
+# for mypy-zope to realize it is an interface.
+class ISynapseReactor(
+ IReactorTCP,
+ IReactorSSL,
+ IReactorPluggableNameResolver,
+ IReactorTime,
+ IReactorCore,
+ IReactorThreads,
+ Interface,
+):
+ """The interfaces necessary for Synapse to function."""
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class Requester:
+ """
+ Represents the user making a request
+
+ Attributes:
+ user: id of the user making the request
+ access_token_id: *ID* of the access token used for this
+ request, or None if it came via the appservice API or similar
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request has been shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
+ authenticated_entity: The entity that authenticated when making the request.
+ This is different to the user_id when an admin user or the server is
+ "puppeting" the user.
+ """
+
+ user: "UserID"
+ access_token_id: Optional[int]
+ is_guest: bool
+ shadow_banned: bool
+ device_id: Optional[str]
+ app_service: Optional["ApplicationService"]
+ authenticated_entity: str
+
+ def serialize(self) -> Dict[str, Any]:
+ """Converts self to a type that can be serialized as JSON, and then
+ deserialized by `deserialize`
+
+ Returns:
+ dict
+ """
+ return {
+ "user_id": self.user.to_string(),
+ "access_token_id": self.access_token_id,
+ "is_guest": self.is_guest,
+ "shadow_banned": self.shadow_banned,
+ "device_id": self.device_id,
+ "app_server_id": self.app_service.id if self.app_service else None,
+ "authenticated_entity": self.authenticated_entity,
+ }
+
+ @staticmethod
+ def deserialize(
+ store: "ApplicationServiceWorkerStore", input: Dict[str, Any]
+ ) -> "Requester":
+ """Converts a dict that was produced by `serialize` back into a
+ Requester.
+
+ Args:
+ store: Used to convert AS ID to AS object
+ input: A dict produced by `serialize`
+
+ Returns:
+ Requester
+ """
+ appservice = None
+ if input["app_server_id"]:
+ appservice = store.get_app_service_by_id(input["app_server_id"])
+
+ return Requester(
+ user=UserID.from_string(input["user_id"]),
+ access_token_id=input["access_token_id"],
+ is_guest=input["is_guest"],
+ shadow_banned=input["shadow_banned"],
+ device_id=input["device_id"],
+ app_service=appservice,
+ authenticated_entity=input["authenticated_entity"],
+ )
+
+
+def create_requester(
+ user_id: Union[str, "UserID"],
+ access_token_id: Optional[int] = None,
+ is_guest: bool = False,
+ shadow_banned: bool = False,
+ device_id: Optional[str] = None,
+ app_service: Optional["ApplicationService"] = None,
+ authenticated_entity: Optional[str] = None,
+) -> Requester:
+ """
+ Create a new ``Requester`` object
+
+ Args:
+ user_id: id of the user making the request
+ access_token_id: *ID* of the access token used for this
+ request, or None if it came via the appservice API or similar
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request is shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
+ authenticated_entity: The entity that authenticated when making the request.
+ This is different to the user_id when an admin user or the server is
+ "puppeting" the user.
+
+ Returns:
+ Requester
+ """
+ if not isinstance(user_id, UserID):
+ user_id = UserID.from_string(user_id)
+
+ if authenticated_entity is None:
+ authenticated_entity = user_id.to_string()
+
+ return Requester(
+ user_id,
+ access_token_id,
+ is_guest,
+ shadow_banned,
+ device_id,
+ app_service,
+ authenticated_entity,
+ )
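+
+# A minimal usage sketch (hypothetical IDs). `create_requester` fills in the
+# remaining `Requester` attributes with defaults, and the result survives a
+# `serialize`/`deserialize` round trip, given a `store` that can resolve
+# appservice IDs:
+#
+#     requester = create_requester("@alice:example.com", device_id="ADEVICE")
+#     assert requester.authenticated_entity == "@alice:example.com"
+#     blob = requester.serialize()  # a plain, JSON-serialisable dict
+#     assert Requester.deserialize(store, blob) == requester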
+
+
+def get_domain_from_id(string: str) -> str:
+ idx = string.find(":")
+ if idx == -1:
+ raise SynapseError(400, "Invalid ID: %r" % (string,))
+ return string[idx + 1 :]
+
+
+def get_localpart_from_id(string: str) -> str:
+ idx = string.find(":")
+ if idx == -1:
+ raise SynapseError(400, "Invalid ID: %r" % (string,))
+ return string[1:idx]
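+
+# For example (hypothetical ID), both helpers split on the first colon only,
+# so domains containing an explicit port are preserved:
+#
+#     get_domain_from_id("@alice:example.com:8448")     # "example.com:8448"
+#     get_localpart_from_id("@alice:example.com:8448")  # "alice"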
+
+
+DS = TypeVar("DS", bound="DomainSpecificString")
+
+
+@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True)
+class DomainSpecificString(metaclass=abc.ABCMeta):
+ """Common base class among ID/name strings that have a local part and a
+ domain name, prefixed with a sigil.
+
+ Has the fields:
+
+ 'localpart' : The local part of the name (without the leading sigil)
+ 'domain' : The domain part of the name
+ """
+
+ SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore
+
+ localpart: str
+ domain: str
+
+ # Because this is a frozen class, it is deeply immutable.
+ def __copy__(self: DS) -> DS:
+ return self
+
+ def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS:
+ return self
+
+ @classmethod
+ def from_string(cls: Type[DS], s: str) -> DS:
+ """Parse the string given by 's' into a structure object."""
+ if len(s) < 1 or s[0:1] != cls.SIGIL:
+ raise SynapseError(
+ 400,
+ "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL),
+ Codes.INVALID_PARAM,
+ )
+
+ parts = s[1:].split(":", 1)
+ if len(parts) != 2:
+ raise SynapseError(
+ 400,
+ "Expected %s of the form '%slocalname:domain'"
+ % (cls.__name__, cls.SIGIL),
+ Codes.INVALID_PARAM,
+ )
+
+ domain = parts[1]
+ # This code will need changing if we want to support multiple domain
+ # names on one HS
+ return cls(localpart=parts[0], domain=domain)
+
+ def to_string(self) -> str:
+ """Return a string encoding the fields of the structure object."""
+ return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
+
+ @classmethod
+ def is_valid(cls: Type[DS], s: str) -> bool:
+ """Parses the input string and attempts to ensure it is valid."""
+ # TODO: this does not reject an empty localpart or an overly-long string.
+ # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar
+ try:
+ obj = cls.from_string(s)
+ # Apply additional validation to the domain. This is only done
+ # during is_valid (and not part of from_string) since it is
+ # possible for invalid data to exist in room-state, etc.
+ parse_and_validate_server_name(obj.domain)
+ return True
+ except Exception:
+ return False
+
+ __repr__ = to_string
+
+
+@attr.s(slots=True, frozen=True, repr=False)
+class UserID(DomainSpecificString):
+ """Structure representing a user ID."""
+
+ SIGIL = "@"
+
+
+@attr.s(slots=True, frozen=True, repr=False)
+class RoomAlias(DomainSpecificString):
+ """Structure representing a room name."""
+
+ SIGIL = "#"
+
+
+@attr.s(slots=True, frozen=True, repr=False)
+class RoomID(DomainSpecificString):
+ """Structure representing a room id."""
+
+ SIGIL = "!"
+
+
+@attr.s(slots=True, frozen=True, repr=False)
+class EventID(DomainSpecificString):
+ """Structure representing an event id."""
+
+ SIGIL = "$"
+
+
+mxid_localpart_allowed_characters = set(
+ "_-./=" + string.ascii_lowercase + string.digits
+)
+
+
+def contains_invalid_mxid_characters(localpart: str) -> bool:
+ """Check for characters not allowed in an mxid or groupid localpart
+
+ Args:
+ localpart: the localpart to be checked
+
+ Returns:
+        True if there are any disallowed characters
+ """
+ return any(c not in mxid_localpart_allowed_characters for c in localpart)
+
+
+UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
+
+# the following is a pattern which matches '=', and bytes which are not allowed in an mxid
+# localpart.
+#
+# It works by:
+# * building a string containing the allowed characters (excluding '=')
+# * escaping every special character with a backslash (to stop '-' being interpreted as a
+# range operator)
+# * wrapping it in a '[^...]' regex
+# * converting the whole lot to a 'bytes' sequence, so that we can use it to match
+# bytes rather than strings
+#
+NON_MXID_CHARACTER_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode(
+ "ascii"
+ )
+)
+
+
+def map_username_to_mxid_localpart(
+ username: Union[str, bytes], case_sensitive: bool = False
+) -> str:
+ """Map a username onto a string suitable for a MXID
+
+ This follows the algorithm laid out at
+ https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
+
+ Args:
+ username: username to be mapped
+ case_sensitive: true if TEST and test should be mapped
+ onto different mxids
+
+ Returns:
+        string suitable for an mxid localpart
+ """
+ if not isinstance(username, bytes):
+ username = username.encode("utf-8")
+
+ # first we sort out upper-case characters
+ if case_sensitive:
+
+ def f1(m: Match[bytes]) -> bytes:
+ return b"_" + m.group().lower()
+
+ username = UPPER_CASE_PATTERN.sub(f1, username)
+ else:
+ username = username.lower()
+
+ # then we sort out non-ascii characters by converting to the hex equivalent.
+ def f2(m: Match[bytes]) -> bytes:
+ return b"=%02x" % (m.group()[0],)
+
+ username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
+
+    # we also apply the =-escaping to mxids starting with an underscore.
+ username = re.sub(b"^_", b"=5f", username)
+
+ # we should now only have ascii bytes left, so can decode back to a string.
+ return username.decode("ascii")
+
+
+@attr.s(frozen=True, slots=True, order=False)
+class RoomStreamToken:
+ """Tokens are positions between events. The token "s1" comes after event 1.
+
+ s0 s1
+ | |
+ [0] â–¼ [1] â–¼ [2]
+
+ Tokens can either be a point in the live event stream or a cursor going
+ through historic events.
+
+ When traversing the live event stream, events are ordered by
+ `stream_ordering` (when they arrived at the homeserver).
+
+ When traversing historic events, events are first ordered by their `depth`
+ (`topological_ordering` in the event graph) and tie-broken by
+ `stream_ordering` (when the event arrived at the homeserver).
+
+ If you're looking for more info about what a token with all of the
+ underscores means, ex.
+ `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring
+ for `StreamToken` below.
+
+ ---
+
+ Live tokens start with an "s" followed by the `stream_ordering` of the event
+ that comes before the position of the token. Said another way:
+ `stream_ordering` uniquely identifies a persisted event. The live token
+ means "the position just after the event identified by `stream_ordering`".
+ An example token is:
+
+ s2633508
+
+ ---
+
+ Historic tokens start with a "t" followed by the `depth`
+ (`topological_ordering` in the event graph) of the event that comes before
+ the position of the token, followed by "-", followed by the
+ `stream_ordering` of the event that comes before the position of the token.
+ An example token is:
+
+ t426-2633508
+
+ ---
+
+    There is also a third mode for live tokens where the token starts with "m",
+    which is sometimes used when sharded event persisters are in use. In this
+    case the events stream is considered to be a set of streams (one for each
+    writer) and the token encodes the vector clock of positions of each writer
+    in their respective streams.
+
+    The format of the token in that case is an initial integer min position,
+    followed by instance ID/position pairs, where each ID and position are
+    joined by '.' and the pairs are separated by '~':
+
+ m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ...
+
+ The `min_pos` corresponds to the minimum position all writers have persisted
+ up to, and then only writers that are ahead of that position need to be
+ encoded. An example token is:
+
+ m56~2.58~3.59
+
+    This corresponds to a set of three (or more) writers, where instances 2 and
+    3 (these are instance IDs that can be looked up in the DB to fetch the more
+    commonly used instance names) are at positions 58 and 59 respectively, and
+    all other instances are at position 56.
+
+ Note: The `RoomStreamToken` cannot have both a topological part and an
+ instance map.
+
+ ---
+
+ For caching purposes, `RoomStreamToken`s and by extension, all their
+ attributes, must be hashable.
+ """
+
+ topological: Optional[int] = attr.ib(
+ validator=attr.validators.optional(attr.validators.instance_of(int)),
+ )
+ stream: int = attr.ib(validator=attr.validators.instance_of(int))
+
+ instance_map: "frozendict[str, int]" = attr.ib(
+ factory=frozendict,
+ validator=attr.validators.deep_mapping(
+ key_validator=attr.validators.instance_of(str),
+ value_validator=attr.validators.instance_of(int),
+ mapping_validator=attr.validators.instance_of(frozendict),
+ ),
+ )
+
+ def __attrs_post_init__(self) -> None:
+ """Validates that both `topological` and `instance_map` aren't set."""
+
+ if self.instance_map and self.topological:
+ raise ValueError(
+ "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
+ )
+
+ @classmethod
+ async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
+ try:
+ if string[0] == "s":
+ return cls(topological=None, stream=int(string[1:]))
+ if string[0] == "t":
+ parts = string[1:].split("-", 1)
+ return cls(topological=int(parts[0]), stream=int(parts[1]))
+ if string[0] == "m":
+ parts = string[1:].split("~")
+ stream = int(parts[0])
+
+ instance_map = {}
+ for part in parts[1:]:
+ key, value = part.split(".")
+ instance_id = int(key)
+ pos = int(value)
+
+ instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined]
+ instance_map[instance_name] = pos
+
+ return cls(
+ topological=None,
+ stream=stream,
+ instance_map=frozendict(instance_map),
+ )
+ except CancelledError:
+ raise
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid room stream token %r" % (string,))
+
+ @classmethod
+ def parse_stream_token(cls, string: str) -> "RoomStreamToken":
+ try:
+ if string[0] == "s":
+ return cls(topological=None, stream=int(string[1:]))
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid room stream token %r" % (string,))
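+
+    # A sketch of how the three token forms parse (hypothetical values;
+    # `parse` is async because the "m" form needs the store to map instance
+    # IDs back to instance names):
+    #
+    #     RoomStreamToken.parse_stream_token("s2633508")
+    #     # -> stream=2633508, topological=None
+    #     await RoomStreamToken.parse(store, "t426-2633508")
+    #     # -> topological=426, stream=2633508
+    #     await RoomStreamToken.parse(store, "m56~2.58~3.59")
+    #     # -> stream=56, with instance_map mapping the names of instances
+    #     #    2 and 3 to positions 58 and 59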
+
+ def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
+ """Return a new token such that if an event is after both this token and
+        the other token, then it's after the returned token too.
+ """
+
+ if self.topological or other.topological:
+ raise Exception("Can't advance topological tokens")
+
+ max_stream = max(self.stream, other.stream)
+
+ instance_map = {
+ instance: max(
+ self.instance_map.get(instance, self.stream),
+ other.instance_map.get(instance, other.stream),
+ )
+ for instance in set(self.instance_map).union(other.instance_map)
+ }
+
+ return RoomStreamToken(None, max_stream, frozendict(instance_map))
+
+ def as_historical_tuple(self) -> Tuple[int, int]:
+ """Returns a tuple of `(topological, stream)` for historical tokens.
+
+        Raises if this is not a historical token (i.e. it doesn't have a topological part).
+ """
+ if self.topological is None:
+ raise Exception(
+ "Cannot call `RoomStreamToken.as_historical_tuple` on live token"
+ )
+
+ return self.topological, self.stream
+
+ def get_stream_pos_for_instance(self, instance_name: str) -> int:
+ """Get the stream position that the given writer was at at this token.
+
+ This only makes sense for "live" tokens that may have a vector clock
+ component, and so asserts that this is a "live" token.
+ """
+ assert self.topological is None
+
+ # If we don't have an entry for the instance we can assume that it was
+ # at `self.stream`.
+ return self.instance_map.get(instance_name, self.stream)
+
+ def get_max_stream_pos(self) -> int:
+ """Get the maximum stream position referenced in this token.
+
+ The corresponding "min" position is, by definition just `self.stream`.
+
+        This is used to handle tokens that have a non-empty `instance_map`, and so
+ reference stream positions after the `self.stream` position.
+ """
+ return max(self.instance_map.values(), default=self.stream)
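+
+    # For example, for a token parsed from "m56~2.58~3.59", supposing
+    # instance IDs 2 and 3 map to the (hypothetical) names "writer2" and
+    # "writer3":
+    #
+    #     token.get_stream_pos_for_instance("writer2")  # 58
+    #     token.get_stream_pos_for_instance("writer9")  # 56 (not in the map)
+    #     token.get_max_stream_pos()                    # 59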
+
+ async def to_string(self, store: "DataStore") -> str:
+ if self.topological is not None:
+ return "t%d-%d" % (self.topological, self.stream)
+ elif self.instance_map:
+ entries = []
+ for name, pos in self.instance_map.items():
+ instance_id = await store.get_id_for_instance(name)
+ entries.append(f"{instance_id}.{pos}")
+
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ else:
+ return "s%d" % (self.stream,)
+
+
+class StreamKeyType:
+ """Known stream types.
+
+ A stream is a list of entities ordered by an incrementing "stream token".
+ """
+
+ ROOM: Final = "room_key"
+ PRESENCE: Final = "presence_key"
+ TYPING: Final = "typing_key"
+ RECEIPT: Final = "receipt_key"
+ ACCOUNT_DATA: Final = "account_data_key"
+ PUSH_RULES: Final = "push_rules_key"
+ TO_DEVICE: Final = "to_device_key"
+ DEVICE_LIST: Final = "device_list_key"
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StreamToken:
+ """A collection of keys joined together by underscores in the following
+ order and which represent the position in their respective streams.
+
+ ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1`
+ 1. `room_key`: `s2633508` which is a `RoomStreamToken`
+        - `RoomStreamToken`s can also look like `t426-2633508` or `m56~2.58~3.59`
+ - See the docstring for `RoomStreamToken` for more details.
+ 2. `presence_key`: `17`
+ 3. `typing_key`: `338`
+ 4. `receipt_key`: `6732159`
+ 5. `account_data_key`: `1082514`
+ 6. `push_rules_key`: `541479`
+ 7. `to_device_key`: `274711`
+ 8. `device_list_key`: `265584`
+ 9. `groups_key`: `1` (note that this key is now unused)
+
+ You can see how many of these keys correspond to the various
+ fields in a "/sync" response:
+ ```json
+ {
+ "next_batch": "s12_4_0_1_1_1_1_4_1",
+ "presence": {
+ "events": []
+ },
+ "device_lists": {
+ "changed": []
+ },
+ "rooms": {
+ "join": {
+ "!QrZlfIDQLNLdZHqTnt:hs1": {
+ "timeline": {
+ "events": [],
+ "prev_batch": "s10_4_0_1_1_1_1_4_1",
+ "limited": false
+ },
+ "state": {
+ "events": []
+ },
+ "account_data": {
+ "events": []
+ },
+ "ephemeral": {
+ "events": []
+ }
+ }
+ }
+ }
+ }
+ ```
+
+ ---
+
+ For caching purposes, `StreamToken`s and by extension, all their attributes,
+ must be hashable.
+ """
+
+ room_key: RoomStreamToken = attr.ib(
+ validator=attr.validators.instance_of(RoomStreamToken)
+ )
+ presence_key: int
+ typing_key: int
+ receipt_key: int
+ account_data_key: int
+ push_rules_key: int
+ to_device_key: int
+ device_list_key: int
+ # Note that the groups key is no longer used and may have bogus values.
+ groups_key: int
+
+ _SEPARATOR = "_"
+ START: ClassVar["StreamToken"]
+
+ @classmethod
+ @cancellable
+ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
+ """
+        Creates a StreamToken from its textual representation.
+ """
+ try:
+ keys = string.split(cls._SEPARATOR)
+ while len(keys) < len(attr.fields(cls)):
+ # i.e. old token from before receipt_key
+ keys.append("0")
+ return cls(
+ await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+ )
+ except CancelledError:
+ raise
+ except Exception:
+ raise SynapseError(400, "Invalid stream token")
+
+ async def to_string(self, store: "DataStore") -> str:
+ return self._SEPARATOR.join(
+ [
+ await self.room_key.to_string(store),
+ str(self.presence_key),
+ str(self.typing_key),
+ str(self.receipt_key),
+ str(self.account_data_key),
+ str(self.push_rules_key),
+ str(self.to_device_key),
+ str(self.device_list_key),
+ # Note that the groups key is no longer used, but it is still
+ # serialized so that there will not be confusion in the future
+ # if additional tokens are added.
+ str(self.groups_key),
+ ]
+ )
+
+ @property
+ def room_stream_id(self) -> int:
+ return self.room_key.stream
+
+ def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
+ """Advance the given key in the token to a new value if and only if the
+ new value is after the old value.
+
+        :raises TypeError: if `key` is not one of the keys tracked by a StreamToken.
+ """
+ if key == StreamKeyType.ROOM:
+ new_token = self.copy_and_replace(
+ StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value)
+ )
+ return new_token
+
+ new_token = self.copy_and_replace(key, new_value)
+ new_id = int(getattr(new_token, key))
+ old_id = int(getattr(self, key))
+
+ if old_id < new_id:
+ return new_token
+ else:
+ return self
+
+ def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
+ return attr.evolve(self, **{key: new_value})
+
+
+StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
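+
+# A usage sketch (hypothetical values): `copy_and_advance` only ever moves a
+# key forwards, so advancing with an older position is a no-op:
+#
+#     token = StreamToken.START.copy_and_replace(StreamKeyType.TYPING, 5)
+#     token.copy_and_advance(StreamKeyType.TYPING, 3).typing_key  # still 5
+#     token.copy_and_advance(StreamKeyType.TYPING, 8).typing_key  # now 8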
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PersistedEventPosition:
+ """Position of a newly persisted event with instance that persisted it.
+
+ This can be used to test whether the event is persisted before or after a
+ RoomStreamToken.
+ """
+
+ instance_name: str
+ stream: int
+
+ def persisted_after(self, token: RoomStreamToken) -> bool:
+ return token.get_stream_pos_for_instance(self.instance_name) < self.stream
+
+ def to_room_stream_token(self) -> RoomStreamToken:
+ """Converts the position to a room stream token such that events
+ persisted in the same room after this position will be after the
+ returned `RoomStreamToken`.
+
+ Note: no guarantees are made about ordering w.r.t. events in other
+ rooms.
+ """
+ # Doing the naive thing satisfies the desired properties described in
+ # the docstring.
+ return RoomStreamToken(None, self.stream)
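+
+# For example (hypothetical values): given a live token at stream position 56,
+# an event persisted at position 58 is after it, but one at 56 is not:
+#
+#     token = RoomStreamToken(None, 56)
+#     PersistedEventPosition("worker1", 58).persisted_after(token)  # True
+#     PersistedEventPosition("worker1", 56).persisted_after(token)  # False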
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThirdPartyInstanceID:
+ appservice_id: Optional[str]
+ network_id: Optional[str]
+
+ # Deny iteration because it will bite you if you try to create a singleton
+ # set by:
+ # users = set(user)
+ def __iter__(self) -> NoReturn:
+ raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
+
+ # Because this class is a frozen class, it is deeply immutable.
+ def __copy__(self) -> "ThirdPartyInstanceID":
+ return self
+
+ def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID":
+ return self
+
+ @classmethod
+ def from_string(cls, s: str) -> "ThirdPartyInstanceID":
+ bits = s.split("|", 2)
+ if len(bits) != 2:
+ raise SynapseError(400, "Invalid ID %r" % (s,))
+
+ return cls(appservice_id=bits[0], network_id=bits[1])
+
+ def to_string(self) -> str:
+ return "%s|%s" % (self.appservice_id, self.network_id)
+
+ __str__ = to_string
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ReadReceipt:
+ """Information about a read-receipt"""
+
+ room_id: str
+ receipt_type: str
+ user_id: str
+ event_ids: List[str]
+ thread_id: Optional[str]
+ data: JsonDict
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class DeviceListUpdates:
+ """
+ An object containing a diff of information regarding other users' device lists, intended for
+ a recipient to carry out device list tracking.
+
+ Attributes:
+ changed: A set of users whose device lists have changed recently.
+        left: A set of users whose device lists the recipient no longer needs to track.
+            Typically this happens when those users no longer share any
+            end-to-end encrypted rooms.
+ """
+
+    # We need to use a factory here, otherwise `set` is not evaluated at
+    # object instantiation, but instead at class definition time. The latter
+    # happens only once, which would give you the same sets across multiple
+    # DeviceListUpdates instances.
+    # Also see: don't define mutable default arguments.
+ changed: Set[str] = attr.ib(factory=set)
+ left: Set[str] = attr.ib(factory=set)
+
+ def __bool__(self) -> bool:
+ return bool(self.changed or self.left)
+
+
+def get_verify_key_from_cross_signing_key(
+ key_info: Mapping[str, Any]
+) -> Tuple[str, VerifyKey]:
+ """Get the key ID and signedjson verify key from a cross-signing key dict
+
+ Args:
+ key_info: a cross-signing key dict, which must have a "keys"
+ property that has exactly one item in it
+
+ Returns:
+ the key ID and verify key for the cross-signing key
+ """
+ # make sure that a `keys` field is provided
+ if "keys" not in key_info:
+ raise ValueError("Invalid key")
+ keys = key_info["keys"]
+ # and that it contains exactly one key
+ if len(keys) == 1:
+ key_id, key_data = next(iter(keys.items()))
+ return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
+ else:
+ raise ValueError("Invalid key")
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class UserInfo:
+ """Holds information about a user. Result of get_userinfo_by_id.
+
+ Attributes:
+ user_id: ID of the user.
+ appservice_id: Application service ID that created this user.
+ consent_server_notice_sent: Version of policy documents the user has been sent.
+ consent_version: Version of policy documents the user has consented to.
+ creation_ts: Creation timestamp of the user.
+ is_admin: True if the user is an admin.
+ is_deactivated: True if the user has been deactivated.
+ is_guest: True if the user is a guest user.
+ is_shadow_banned: True if the user has been shadow-banned.
+        user_type: User type (None for a normal user; 'support' and 'bot' are the other options).
+ """
+
+ user_id: UserID
+ appservice_id: Optional[int]
+ consent_server_notice_sent: Optional[str]
+ consent_version: Optional[str]
+ user_type: Optional[str]
+ creation_ts: int
+ is_admin: bool
+ is_deactivated: bool
+ is_guest: bool
+ is_shadow_banned: bool
+
+
+class UserProfile(TypedDict):
+ user_id: str
+ display_name: Optional[str]
+ avatar_url: Optional[str]
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class RetentionPolicy:
+ min_lifetime: Optional[int] = None
+ max_lifetime: Optional[int] = None
diff --git a/synapse/types/state.py b/synapse/types/state.py
new file mode 100644
index 0000000000..743a4f9217
--- /dev/null
+++ b/synapse/types/state.py
@@ -0,0 +1,585 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import (
+ TYPE_CHECKING,
+ Callable,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
+
+import attr
+from frozendict import frozendict
+
+from synapse.api.constants import EventTypes
+from synapse.types import MutableStateMap, StateKey, StateMap
+
+if TYPE_CHECKING:
+ from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
+
+
+logger = logging.getLogger(__name__)
+
+# Used for generic functions below
+T = TypeVar("T")
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StateFilter:
+ """A filter used when querying for state.
+
+ Attributes:
+ types: Map from type to set of state keys (or None). This specifies
+ which state_keys for the given type to fetch from the DB. If None
+ then all events with that type are fetched. If the set is empty
+ then no events with that type are fetched.
+ include_others: Whether to fetch events with types that do not
+ appear in `types`.
+ """
+
+ types: "frozendict[str, Optional[FrozenSet[str]]]"
+ include_others: bool = False
+
+ def __attrs_post_init__(self) -> None:
+ # If `include_others` is set we canonicalise the filter by removing
+ # wildcards from the types dictionary
+ if self.include_others:
+ # this is needed to work around the fact that StateFilter is frozen
+ object.__setattr__(
+ self,
+ "types",
+ frozendict({k: v for k, v in self.types.items() if v is not None}),
+ )
+
+ @staticmethod
+ def all() -> "StateFilter":
+ """Returns a filter that fetches everything.
+
+ Returns:
+ The state filter.
+ """
+ return _ALL_STATE_FILTER
+
+ @staticmethod
+ def none() -> "StateFilter":
+ """Returns a filter that fetches nothing.
+
+ Returns:
+ The new state filter.
+ """
+ return _NONE_STATE_FILTER
+
+ @staticmethod
+ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
+ """Creates a filter that only fetches the given types
+
+ Args:
+ types: A list of type and state keys to fetch. A state_key of None
+ fetches everything for that type
+
+ Returns:
+ The new state filter.
+ """
+ type_dict: Dict[str, Optional[Set[str]]] = {}
+ for typ, s in types:
+ if typ in type_dict:
+ if type_dict[typ] is None:
+ continue
+
+ if s is None:
+ type_dict[typ] = None
+ continue
+
+ type_dict.setdefault(typ, set()).add(s) # type: ignore
+
+ return StateFilter(
+ types=frozendict(
+ (k, frozenset(v) if v is not None else None)
+ for k, v in type_dict.items()
+ )
+ )
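+
+    # For example, a filter which fetches the create event, every member
+    # event, and nothing else:
+    #
+    #     StateFilter.from_types([("m.room.create", ""), ("m.room.member", None)])
+    #
+    # is equivalent to:
+    #
+    #     StateFilter(
+    #         types=frozendict(
+    #             {"m.room.create": frozenset({""}), "m.room.member": None}
+    #         ),
+    #         include_others=False,
+    #     )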
+
+ def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
+ """The inverse to `from_types`."""
+ for (event_type, state_keys) in self.types.items():
+ if state_keys is None:
+ yield event_type, None
+ else:
+ for state_key in state_keys:
+ yield event_type, state_key
+
+ @staticmethod
+ def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
+ """Creates a filter that returns all non-member events, plus the member
+ events for the given users
+
+ Args:
+ members: Set of user IDs
+
+ Returns:
+ The new state filter
+ """
+ return StateFilter(
+ types=frozendict({EventTypes.Member: frozenset(members)}),
+ include_others=True,
+ )
+
+ @staticmethod
+ def freeze(
+ types: Mapping[str, Optional[Collection[str]]], include_others: bool
+ ) -> "StateFilter":
+ """
+ Returns a (frozen) StateFilter with the same contents as the parameters
+ specified here, which can be made of mutable types.
+ """
+ types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
+ for state_types, state_keys in types.items():
+ if state_keys is not None:
+ types_with_frozen_values[state_types] = frozenset(state_keys)
+ else:
+ types_with_frozen_values[state_types] = None
+
+ return StateFilter(
+ frozendict(types_with_frozen_values), include_others=include_others
+ )
+
+ def return_expanded(self) -> "StateFilter":
+ """Creates a new StateFilter where type wild cards have been removed
+ (except for memberships). The returned filter is a superset of the
+ current one, i.e. anything that passes the current filter will pass
+ the returned filter.
+
+ This helps the caching as the DictionaryCache knows if it has *all* the
+ state, but does not know if it has all of the keys of a particular type,
+ which makes wildcard lookups expensive unless we have a complete cache.
+ Hence, if we are doing a wildcard lookup, populate the cache fully so
+ that we can do an efficient lookup next time.
+
+ Note that since we have two caches, one for membership events and one for
+ other events, we can be a bit more clever than simply returning
+ `StateFilter.all()` if `has_wildcards()` is True.
+
+ We return a StateFilter where:
+ 1. the list of membership events to return is the same
+ 2. if there is a wildcard that matches non-member events we
+ return all non-member events
+
+ Returns:
+ The new state filter.
+ """
+
+ if self.is_full():
+ # If we're going to return everything then there's nothing to do
+ return self
+
+ if not self.has_wildcards():
+            # If there are no wildcards, there's nothing to do
+ return self
+
+ if EventTypes.Member in self.types:
+ get_all_members = self.types[EventTypes.Member] is None
+ else:
+ get_all_members = self.include_others
+
+ has_non_member_wildcard = self.include_others or any(
+ state_keys is None
+ for t, state_keys in self.types.items()
+ if t != EventTypes.Member
+ )
+
+ if not has_non_member_wildcard:
+            # If there are no non-member wildcards we can just return ourselves
+ return self
+
+ if get_all_members:
+ # We want to return everything.
+ return StateFilter.all()
+ elif EventTypes.Member in self.types:
+ # We want to return all non-members, but only particular
+ # memberships
+ return StateFilter(
+ types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
+ include_others=True,
+ )
+ else:
+ # We want to return all non-members
+ return _ALL_NON_MEMBER_STATE_FILTER
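+
+    # For example (hypothetical user ID), a filter asking for one member plus
+    # a wildcard over another type expands to "that member, plus all
+    # non-member state":
+    #
+    #     f = StateFilter.freeze(
+    #         {EventTypes.Member: {"@alice:example.com"}, "m.room.topic": None},
+    #         include_others=False,
+    #     )
+    #     f.return_expanded()
+    #     # == StateFilter.freeze(
+    #     #        {EventTypes.Member: {"@alice:example.com"}},
+    #     #        include_others=True,
+    #     #    )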
+
+ def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
+ """Converts the filter to an SQL clause.
+
+ For example:
+
+ f = StateFilter.from_types([("m.room.create", "")])
+ clause, args = f.make_sql_filter_clause()
+ clause == "(type = ? AND state_key = ?)"
+ args == ['m.room.create', '']
+
+
+ Returns:
+ The SQL string (may be empty) and arguments. An empty SQL string is
+ returned when the filter matches everything (i.e. is "full").
+ """
+
+ where_clause = ""
+ where_args: List[str] = []
+
+ if self.is_full():
+ return where_clause, where_args
+
+ if not self.include_others and not self.types:
+ # i.e. this is an empty filter, so we need to return a clause that
+ # will match nothing
+ return "1 = 2", []
+
+        # First we build up a list of clauses for each type/state_key combo
+ clauses = []
+ for etype, state_keys in self.types.items():
+ if state_keys is None:
+ clauses.append("(type = ?)")
+ where_args.append(etype)
+ continue
+
+ for state_key in state_keys:
+ clauses.append("(type = ? AND state_key = ?)")
+ where_args.extend((etype, state_key))
+
+ # This will match anything that appears in `self.types`
+ where_clause = " OR ".join(clauses)
+
+        # If we want to include types that don't appear in the types dict then
+        # we add an `OR type NOT IN (...)` clause to the end.
+ if self.include_others:
+ if where_clause:
+ where_clause += " OR "
+
+ where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
+ where_args.extend(self.types)
+
+ return where_clause, where_args
+
+ def max_entries_returned(self) -> Optional[int]:
+ """Returns the maximum number of entries this filter will return if
+ known, otherwise returns None.
+
+ For example a simple state filter asking for `("m.room.create", "")`
+ will return 1, whereas the default state filter will return None.
+
+ This is used to bail out early if the right number of entries have been
+ fetched.
+ """
+ if self.has_wildcards():
+ return None
+
+ return len(self.concrete_types())
+
+ def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
+ """Returns the state filtered with by this StateFilter.
+
+ Args:
+ state: The state map to filter
+
+ Returns:
+ The filtered state map.
+ This is a copy, so it's safe to mutate.
+ """
+ if self.is_full():
+ return dict(state_dict)
+
+ filtered_state = {}
+ for k, v in state_dict.items():
+ typ, state_key = k
+ if typ in self.types:
+ state_keys = self.types[typ]
+ if state_keys is None or state_key in state_keys:
+ filtered_state[k] = v
+ elif self.include_others:
+ filtered_state[k] = v
+
+ return filtered_state
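+
+    # For example (hypothetical events), a filter that keeps only member
+    # events:
+    #
+    #     f = StateFilter.freeze({"m.room.member": None}, include_others=False)
+    #     f.filter_state({
+    #         ("m.room.member", "@alice:example.com"): member_event,
+    #         ("m.room.create", ""): create_event,
+    #     })
+    #     # == {("m.room.member", "@alice:example.com"): member_event}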
+
+ def is_full(self) -> bool:
+ """Whether this filter fetches everything or not
+
+ Returns:
+ True if the filter fetches everything.
+ """
+ return self.include_others and not self.types
+
+ def has_wildcards(self) -> bool:
+ """Whether the filter includes wildcards or is attempting to fetch
+ specific state.
+
+ Returns:
+ True if the filter includes wildcards.
+ """
+
+ return self.include_others or any(
+ state_keys is None for state_keys in self.types.values()
+ )
+
+ def concrete_types(self) -> List[Tuple[str, str]]:
+ """Returns a list of concrete type/state_keys (i.e. not None) that
+ will be fetched. This will be a complete list if `has_wildcards`
+ returns False, but otherwise will be a subset (or even empty).
+
+ Returns:
+ A list of type/state_keys tuples.
+ """
+ return [
+ (t, s)
+ for t, state_keys in self.types.items()
+ if state_keys is not None
+ for s in state_keys
+ ]
+
+ def wildcard_types(self) -> List[str]:
+ """Returns a list of event types which require us to fetch all state keys.
+ This will be empty unless `has_wildcards` returns True.
+
+ Returns:
+ A list of event types.
+ """
+ return [t for t, state_keys in self.types.items() if state_keys is None]
+
+ def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
+ """Return the filter split into two: one which assumes it's exclusively
+ matching against member state, and one which assumes it's matching
+        against non-member state.
+
+ This is useful due to the returned filters giving correct results for
+ `is_full()`, `has_wildcards()`, etc, when operating against maps that
+ either exclusively contain member events or only contain non-member
+        events, which is the case when dealing with the member vs non-member
+        state caches.
+
+ Returns:
+            The member and non-member filters
+ """
+
+ if EventTypes.Member in self.types:
+ state_keys = self.types[EventTypes.Member]
+ if state_keys is None:
+ member_filter = StateFilter.all()
+ else:
+ member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
+ elif self.include_others:
+ member_filter = StateFilter.all()
+ else:
+ member_filter = StateFilter.none()
+
+ non_member_filter = StateFilter(
+ types=frozendict(
+ {k: v for k, v in self.types.items() if k != EventTypes.Member}
+ ),
+ include_others=self.include_others,
+ )
+
+ return member_filter, non_member_filter
+
+ def _decompose_into_four_parts(
+ self,
+ ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
+ """
+ Decomposes this state filter into 4 constituent parts, which can be
+ thought of as this:
+ all? - minus_wildcards + plus_wildcards + plus_state_keys
+
+ where
+ * all represents ALL state
+ * minus_wildcards represents entire state types to remove
+ * plus_wildcards represents entire state types to add
+ * plus_state_keys represents individual state keys to add
+
+        See `_recompose_from_four_parts` for the other direction of this
+ correspondence.
+ """
+ is_all = self.include_others
+ excluded_types: Set[str] = {t for t in self.types if is_all}
+ wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
+ concrete_keys: Set[StateKey] = set(self.concrete_types())
+
+ return (is_all, excluded_types), (wildcard_types, concrete_keys)
+
+ @staticmethod
+ def _recompose_from_four_parts(
+ all_part: bool,
+ minus_wildcards: Set[str],
+ plus_wildcards: Set[str],
+ plus_state_keys: Set[StateKey],
+ ) -> "StateFilter":
+ """
+ Recomposes a state filter from 4 parts.
+
+        See `_decompose_into_four_parts` (the other direction of this
+ correspondence) for descriptions on each of the parts.
+ """
+
+ # {state type -> set of state keys OR None for wildcard}
+ # (The same structure as that of a StateFilter.)
+ new_types: Dict[str, Optional[Set[str]]] = {}
+
+        # if we start with all, insert the excluded state types as empty sets
+ # to prevent them from being included
+ if all_part:
+ new_types.update({state_type: set() for state_type in minus_wildcards})
+
+ # insert the plus wildcards
+ new_types.update({state_type: None for state_type in plus_wildcards})
+
+ # insert the specific state keys
+ for state_type, state_key in plus_state_keys:
+ if state_type in new_types:
+ entry = new_types[state_type]
+ if entry is not None:
+ entry.add(state_key)
+ elif not all_part:
+ # don't insert if the entire type is already included by
+ # include_others as this would actually shrink the state allowed
+ # by this filter.
+ new_types[state_type] = {state_key}
+
+ return StateFilter.freeze(new_types, include_others=all_part)
+
+ def approx_difference(self, other: "StateFilter") -> "StateFilter":
+ """
+ Returns a state filter which represents `self - other`.
+
+ This is useful for determining what state remains to be pulled out of the
+ database if we want the state included by `self` but already have the state
+ included by `other`.
+
+ The returned state filter
+ - MUST include all state events that are included by this filter (`self`)
+ unless they are included by `other`;
+ - MUST NOT include state events not included by this filter (`self`); and
+ - MAY be an over-approximation: the returned state filter
+ MAY additionally include some state events from `other`.
+
+ This implementation attempts to return the narrowest such state filter.
+ In the case that `self` contains wildcards for state types where
+ `other` contains specific state keys, an approximation must be made:
+ the returned state filter keeps the wildcard, as state filters are not
+ able to express 'all state keys except some given examples'.
+ e.g.
+ StateFilter(m.room.member -> None (wildcard))
+ minus
+ StateFilter(m.room.member -> {'@wombat:example.org'})
+ is approximated as
+ StateFilter(m.room.member -> None (wildcard))
+ """
+
+ # We first transform self and other into an alternative representation:
+ # - whether or not they include all events to begin with ('all')
+ # - if so, which event types are excluded? ('excludes')
+ # - which entire event types to include ('wildcards')
+ # - which concrete state keys to include ('concrete state keys')
+ (self_all, self_excludes), (
+ self_wildcards,
+ self_concrete_keys,
+ ) = self._decompose_into_four_parts()
+ (other_all, other_excludes), (
+ other_wildcards,
+ other_concrete_keys,
+ ) = other._decompose_into_four_parts()
+
+ # Start with an estimate of the difference based on self
+ new_all = self_all
+ # Wildcards from the other can be added to the exclusion filter
+ new_excludes = self_excludes | other_wildcards
+ # We remove wildcards that appeared as wildcards in the other
+ new_wildcards = self_wildcards - other_wildcards
+ # We filter out the concrete state keys that appear in the other
+ # as wildcards or concrete state keys.
+ new_concrete_keys = {
+ (state_type, state_key)
+ for (state_type, state_key) in self_concrete_keys
+ if state_type not in other_wildcards
+ } - other_concrete_keys
+
+ if other_all:
+ if self_all:
+ # If self starts with all, then we add as wildcards any
+ # types which appear in the other's exclusion filter (but
+ # aren't in the self exclusion filter). This is as the other
+ # filter will return everything BUT the types in its exclusion, so
+ # we need to add those excluded types that also match the self
+ # filter as wildcard types in the new filter.
+ new_wildcards |= other_excludes.difference(self_excludes)
+
+            # If `other` is an `include_others` filter then the difference isn't one.
+ new_all = False
+ # (We have no need for excludes when we don't start with all, as there
+ # is nothing to exclude.)
+ new_excludes = set()
+
+ # We also filter out all state types that aren't in the exclusion
+ # list of the other.
+ new_wildcards &= other_excludes
+ new_concrete_keys = {
+ (state_type, state_key)
+ for (state_type, state_key) in new_concrete_keys
+ if state_type in other_excludes
+ }
+
+ # Transform our newly-constructed state filter from the alternative
+ # representation back into the normal StateFilter representation.
+ return StateFilter._recompose_from_four_parts(
+ new_all, new_excludes, new_wildcards, new_concrete_keys
+ )
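+
+    # A worked example (hypothetical state keys): with only concrete state
+    # keys on both sides, the difference is exact:
+    #
+    #     self_f = StateFilter.from_types(
+    #         [("m.room.member", "@u:a"), ("m.room.member", "@v:a"),
+    #          ("m.room.create", "")]
+    #     )
+    #     other_f = StateFilter.from_types([("m.room.member", "@u:a")])
+    #     self_f.approx_difference(other_f)
+    #     # == StateFilter.from_types(
+    #     #        [("m.room.member", "@v:a"), ("m.room.create", "")]
+    #     #    )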
+
+ def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
+ """Check if we need to wait for full state to complete to calculate this state
+
+ If we have a state filter which is completely satisfied even with partial
+ state, then we don't need to await_full_state before we can return it.
+
+ Args:
+            is_mine_id: a callable which confirms if a given state_key matches an mxid
+ of a local user
+ """
+ # if we haven't requested membership events, then it depends on the value of
+ # 'include_others'
+ if EventTypes.Member not in self.types:
+ return self.include_others
+
+ # if we're looking for *all* membership events, then we have to wait
+ member_state_keys = self.types[EventTypes.Member]
+ if member_state_keys is None:
+ return True
+
+ # otherwise, consider whose membership we are looking for. If it's entirely
+ # local users, then we don't need to wait.
+ for state_key in member_state_keys:
+ if not is_mine_id(state_key):
+ # remote user
+ return True
+
+ # local users only
+ return False
+
+
+_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
+)
+_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)