summary refs log tree commit diff
path: root/synapse/types.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/types.py')
-rw-r--r--synapse/types.py281
1 files changed, 183 insertions, 98 deletions
diff --git a/synapse/types.py b/synapse/types.py

index 7cc523e4f8..bc73e3775d 100644 --- a/synapse/types.py +++ b/synapse/types.py
@@ -13,11 +13,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import re import string import sys from collections import namedtuple -from typing import Any, Dict, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Mapping, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, +) import attr from signedjson.key import decode_verify_key_bytes @@ -26,6 +37,9 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): from typing import Collection @@ -34,15 +48,16 @@ else: T_co = TypeVar("T_co", covariant=True) - class Collection(Iterable[T_co], Container[T_co], Sized): + class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore __slots__ = () # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") -StateMap = Dict[Tuple[str, str], T] - +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] # the type of a JSON-serialisable dict. This could be made stronger, but it will # do for now. @@ -51,7 +66,15 @@ JsonDict = Dict[str, Any] class Requester( namedtuple( - "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] + "Requester", + [ + "user", + "access_token_id", + "is_guest", + "shadow_banned", + "device_id", + "app_service", + ], ) ): """ @@ -62,6 +85,7 @@ class Requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request has been shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user """ @@ -77,6 +101,7 @@ class Requester( "user_id": self.user.to_string(), "access_token_id": self.access_token_id, "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, } @@ -101,13 +126,19 @@ class Requester( user=UserID.from_string(input["user_id"]), access_token_id=input["access_token_id"], is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, ) def create_requester( - user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None + user_id, + access_token_id=None, + is_guest=False, + shadow_banned=False, + device_id=None, + app_service=None, ): """ Create a new ``Requester`` object @@ -117,6 +148,7 @@ def create_requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user @@ -125,7 +157,9 @@ def create_requester( """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) - return Requester(user_id, access_token_id, is_guest, device_id, app_service) + return Requester( + user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + ) def get_domain_from_id(string): @@ -142,7 +176,12 @@ def get_localpart_from_id(string): return string[1:idx] -class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): +DS = TypeVar("DS", bound="DomainSpecificString") + + +class DomainSpecificString( + namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta +): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -152,6 +191,8 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 'domain' : The domain part of the name """ + SIGIL = abc.abstractproperty() # type: str # type: ignore + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) @@ -167,7 +208,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s: str): + def from_string(cls: Type[DS], s: str) -> DS: """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( @@ -191,12 +232,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom # names on one HS return cls(localpart=parts[0], domain=domain) - def to_string(self): + def to_string(self) -> str: """Return a string encoding the fields of the structure object.""" return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) @classmethod - def is_valid(cls, s): + def is_valid(cls: Type[DS], s: str) -> bool: try: cls.from_string(s) return True @@ -236,8 +277,9 @@ class GroupID(DomainSpecificString): SIGIL = "+" @classmethod - def from_string(cls, s): - group_id = super(GroupID, cls).from_string(s) + def from_string(cls: Type[DS], s: str) -> DS: + group_id = super().from_string(s) # type: DS # type: ignore + if not group_id.localpart: raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) @@ -347,86 +389,8 @@ def map_username_to_mxid_localpart(username, case_sensitive=False): return username.decode("ascii") -class StreamToken( - namedtuple( - "Token", - ( - "room_key", - "presence_key", - "typing_key", - "receipt_key", - "account_data_key", - "push_rules_key", - "to_device_key", - "device_list_key", - "groups_key", - ), - ) -): - _SEPARATOR = "_" - START = None # type: StreamToken - - @classmethod - def from_string(cls, string): - try: - keys = string.split(cls._SEPARATOR) - while len(keys) < len(cls._fields): - # i.e. old token from before receipt_key - keys.append("0") - return cls(*keys) - except Exception: - raise SynapseError(400, "Invalid Token") - - def to_string(self): - return self._SEPARATOR.join([str(k) for k in self]) - - @property - def room_stream_id(self): - # TODO(markjh): Awful hack to work around hacks in the presence tests - # which assume that the keys are integers. - if type(self.room_key) is int: - return self.room_key - else: - return int(self.room_key[1:].split("-")[-1]) - - def is_after(self, other): - """Does this token contain events that the other doesn't?""" - return ( - (other.room_stream_id < self.room_stream_id) - or (int(other.presence_key) < int(self.presence_key)) - or (int(other.typing_key) < int(self.typing_key)) - or (int(other.receipt_key) < int(self.receipt_key)) - or (int(other.account_data_key) < int(self.account_data_key)) - or (int(other.push_rules_key) < int(self.push_rules_key)) - or (int(other.to_device_key) < int(self.to_device_key)) - or (int(other.device_list_key) < int(self.device_list_key)) - or (int(other.groups_key) < int(self.groups_key)) - ) - - def copy_and_advance(self, key, new_value): - """Advance the given key in the token to a new value if and only if the - new value is after the old value. - """ - new_token = self.copy_and_replace(key, new_value) - if key == "room_key": - new_id = new_token.room_stream_id - old_id = self.room_stream_id - else: - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) - if old_id < new_id: - return new_token - else: - return self - - def copy_and_replace(self, key, new_value): - return self._replace(**{key: new_value}) - - -StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))) - - -class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): +@attr.s(frozen=True, slots=True) +class RoomStreamToken: """Tokens are positions between events. The token "s1" comes after event 1. s0 s1 @@ -449,10 +413,14 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): followed by the "stream_ordering" id of the event it comes after. """ - __slots__ = [] # type: list + topological = attr.ib( + type=Optional[int], + validator=attr.validators.optional(attr.validators.instance_of(int)), + ) + stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) @classmethod - def parse(cls, string): + async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -464,7 +432,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): raise SynapseError(400, "Invalid token %r" % (string,)) @classmethod - def parse_stream_token(cls, string): + def parse_stream_token(cls, string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -472,13 +440,130 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): pass raise SynapseError(400, "Invalid token %r" % (string,)) - def __str__(self): + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + return RoomStreamToken(None, max_stream) + + def as_tuple(self) -> Tuple[Optional[int], int]: + return (self.topological, self.stream) + + async def to_string(self, store: "DataStore") -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) +@attr.s(slots=True, frozen=True) +class StreamToken: + room_key = attr.ib( + type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken) + ) + presence_key = attr.ib(type=int) + typing_key = attr.ib(type=int) + receipt_key = attr.ib(type=int) + account_data_key = attr.ib(type=int) + push_rules_key = attr.ib(type=int) + to_device_key = attr.ib(type=int) + device_list_key = attr.ib(type=int) + groups_key = attr.ib(type=int) + + _SEPARATOR = "_" + START = None # type: StreamToken + + @classmethod + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": + try: + keys = string.split(cls._SEPARATOR) + while len(keys) < len(attr.fields(cls)): + # i.e. old token from before receipt_key + keys.append("0") + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) + except Exception: + raise SynapseError(400, "Invalid Token") + + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + str(self.groups_key), + ] + ) + + @property + def room_stream_id(self): + return self.room_key.stream + + def copy_and_advance(self, key, new_value) -> "StreamToken": + """Advance the given key in the token to a new value if and only if the + new value is after the old value. + """ + if key == "room_key": + new_token = self.copy_and_replace( + "room_key", self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + + if old_id < new_id: + return new_token + else: + return self + + def copy_and_replace(self, key, new_value) -> "StreamToken": + return attr.evolve(self, **{key: new_value}) + + +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) + + +@attr.s(slots=True, frozen=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name = attr.ib(type=str) + stream = attr.ib(type=int) + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.stream < self.stream + + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarentees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + + class ThirdPartyInstanceID( namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) ): @@ -514,7 +599,7 @@ class ThirdPartyInstanceID( @attr.s(slots=True) -class ReadReceipt(object): +class ReadReceipt: """Information about a read-receipt""" room_id = attr.ib()