diff --git a/synapse/types.py b/synapse/types.py
index 7cc523e4f8..bc73e3775d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -13,11 +13,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import re
import string
import sys
from collections import namedtuple
-from typing import Any, Dict, Tuple, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+)
import attr
from signedjson.key import decode_verify_key_bytes
@@ -26,6 +37,9 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
+if TYPE_CHECKING:
+ from synapse.storage.databases.main import DataStore
+
# define a version of typing.Collection that works on python 3.5
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
@@ -34,15 +48,16 @@ else:
T_co = TypeVar("T_co", covariant=True)
- class Collection(Iterable[T_co], Container[T_co], Sized):
+ class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore
__slots__ = ()
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
-StateMap = Dict[Tuple[str, str], T]
-
+StateKey = Tuple[str, str]
+StateMap = Mapping[StateKey, T]
+MutableStateMap = MutableMapping[StateKey, T]
# the type of a JSON-serialisable dict. This could be made stronger, but it will
# do for now.
@@ -51,7 +66,15 @@ JsonDict = Dict[str, Any]
class Requester(
namedtuple(
- "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
+ "Requester",
+ [
+ "user",
+ "access_token_id",
+ "is_guest",
+ "shadow_banned",
+ "device_id",
+ "app_service",
+ ],
)
):
"""
@@ -62,6 +85,7 @@ class Requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
+ shadow_banned (bool): True if the user making this request has been shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
"""
@@ -77,6 +101,7 @@ class Requester(
"user_id": self.user.to_string(),
"access_token_id": self.access_token_id,
"is_guest": self.is_guest,
+ "shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
}
@@ -101,13 +126,19 @@ class Requester(
user=UserID.from_string(input["user_id"]),
access_token_id=input["access_token_id"],
is_guest=input["is_guest"],
+ shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
)
def create_requester(
- user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
+ user_id,
+ access_token_id=None,
+ is_guest=False,
+ shadow_banned=False,
+ device_id=None,
+ app_service=None,
):
"""
Create a new ``Requester`` object
@@ -117,6 +148,7 @@ def create_requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
+ shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
@@ -125,7 +157,9 @@ def create_requester(
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
- return Requester(user_id, access_token_id, is_guest, device_id, app_service)
+ return Requester(
+ user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+ )
def get_domain_from_id(string):
@@ -142,7 +176,12 @@ def get_localpart_from_id(string):
return string[1:idx]
-class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
+DS = TypeVar("DS", bound="DomainSpecificString")
+
+
+class DomainSpecificString(
+ namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta
+):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -152,6 +191,8 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
'domain' : The domain part of the name
"""
+ SIGIL = abc.abstractproperty() # type: str # type: ignore
+
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
@@ -167,7 +208,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
return self
@classmethod
- def from_string(cls, s: str):
+ def from_string(cls: Type[DS], s: str) -> DS:
"""Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(
@@ -191,12 +232,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
# names on one HS
return cls(localpart=parts[0], domain=domain)
- def to_string(self):
+ def to_string(self) -> str:
"""Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
@classmethod
- def is_valid(cls, s):
+ def is_valid(cls: Type[DS], s: str) -> bool:
try:
cls.from_string(s)
return True
@@ -236,8 +277,9 @@ class GroupID(DomainSpecificString):
SIGIL = "+"
@classmethod
- def from_string(cls, s):
- group_id = super(GroupID, cls).from_string(s)
+ def from_string(cls: Type[DS], s: str) -> DS:
+ group_id = super().from_string(s) # type: DS # type: ignore
+
if not group_id.localpart:
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
@@ -347,86 +389,8 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")
-class StreamToken(
- namedtuple(
- "Token",
- (
- "room_key",
- "presence_key",
- "typing_key",
- "receipt_key",
- "account_data_key",
- "push_rules_key",
- "to_device_key",
- "device_list_key",
- "groups_key",
- ),
- )
-):
- _SEPARATOR = "_"
- START = None # type: StreamToken
-
- @classmethod
- def from_string(cls, string):
- try:
- keys = string.split(cls._SEPARATOR)
- while len(keys) < len(cls._fields):
- # i.e. old token from before receipt_key
- keys.append("0")
- return cls(*keys)
- except Exception:
- raise SynapseError(400, "Invalid Token")
-
- def to_string(self):
- return self._SEPARATOR.join([str(k) for k in self])
-
- @property
- def room_stream_id(self):
- # TODO(markjh): Awful hack to work around hacks in the presence tests
- # which assume that the keys are integers.
- if type(self.room_key) is int:
- return self.room_key
- else:
- return int(self.room_key[1:].split("-")[-1])
-
- def is_after(self, other):
- """Does this token contain events that the other doesn't?"""
- return (
- (other.room_stream_id < self.room_stream_id)
- or (int(other.presence_key) < int(self.presence_key))
- or (int(other.typing_key) < int(self.typing_key))
- or (int(other.receipt_key) < int(self.receipt_key))
- or (int(other.account_data_key) < int(self.account_data_key))
- or (int(other.push_rules_key) < int(self.push_rules_key))
- or (int(other.to_device_key) < int(self.to_device_key))
- or (int(other.device_list_key) < int(self.device_list_key))
- or (int(other.groups_key) < int(self.groups_key))
- )
-
- def copy_and_advance(self, key, new_value):
- """Advance the given key in the token to a new value if and only if the
- new value is after the old value.
- """
- new_token = self.copy_and_replace(key, new_value)
- if key == "room_key":
- new_id = new_token.room_stream_id
- old_id = self.room_stream_id
- else:
- new_id = int(getattr(new_token, key))
- old_id = int(getattr(self, key))
- if old_id < new_id:
- return new_token
- else:
- return self
-
- def copy_and_replace(self, key, new_value):
- return self._replace(**{key: new_value})
-
-
-StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
-
-
-class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
+@attr.s(frozen=True, slots=True)
+class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
@@ -449,10 +413,14 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
followed by the "stream_ordering" id of the event it comes after.
"""
- __slots__ = [] # type: list
+ topological = attr.ib(
+ type=Optional[int],
+ validator=attr.validators.optional(attr.validators.instance_of(int)),
+ )
+ stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
@classmethod
- def parse(cls, string):
+ async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
@@ -464,7 +432,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
raise SynapseError(400, "Invalid token %r" % (string,))
@classmethod
- def parse_stream_token(cls, string):
+ def parse_stream_token(cls, string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
@@ -472,13 +440,130 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
pass
raise SynapseError(400, "Invalid token %r" % (string,))
- def __str__(self):
+ def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
+ """Return a new token such that if an event is after both this token and
+ the other token, then its after the returned token too.
+ """
+
+ if self.topological or other.topological:
+ raise Exception("Can't advance topological tokens")
+
+ max_stream = max(self.stream, other.stream)
+
+ return RoomStreamToken(None, max_stream)
+
+ def as_tuple(self) -> Tuple[Optional[int], int]:
+ return (self.topological, self.stream)
+
+ async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
+@attr.s(slots=True, frozen=True)
+class StreamToken:
+ room_key = attr.ib(
+ type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
+ )
+ presence_key = attr.ib(type=int)
+ typing_key = attr.ib(type=int)
+ receipt_key = attr.ib(type=int)
+ account_data_key = attr.ib(type=int)
+ push_rules_key = attr.ib(type=int)
+ to_device_key = attr.ib(type=int)
+ device_list_key = attr.ib(type=int)
+ groups_key = attr.ib(type=int)
+
+ _SEPARATOR = "_"
+ START = None # type: StreamToken
+
+ @classmethod
+ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
+ try:
+ keys = string.split(cls._SEPARATOR)
+ while len(keys) < len(attr.fields(cls)):
+ # i.e. old token from before receipt_key
+ keys.append("0")
+ return cls(
+ await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
+ )
+ except Exception:
+ raise SynapseError(400, "Invalid Token")
+
+ async def to_string(self, store: "DataStore") -> str:
+ return self._SEPARATOR.join(
+ [
+ await self.room_key.to_string(store),
+ str(self.presence_key),
+ str(self.typing_key),
+ str(self.receipt_key),
+ str(self.account_data_key),
+ str(self.push_rules_key),
+ str(self.to_device_key),
+ str(self.device_list_key),
+ str(self.groups_key),
+ ]
+ )
+
+ @property
+ def room_stream_id(self):
+ return self.room_key.stream
+
+ def copy_and_advance(self, key, new_value) -> "StreamToken":
+ """Advance the given key in the token to a new value if and only if the
+ new value is after the old value.
+ """
+ if key == "room_key":
+ new_token = self.copy_and_replace(
+ "room_key", self.room_key.copy_and_advance(new_value)
+ )
+ return new_token
+
+ new_token = self.copy_and_replace(key, new_value)
+ new_id = int(getattr(new_token, key))
+ old_id = int(getattr(self, key))
+
+ if old_id < new_id:
+ return new_token
+ else:
+ return self
+
+ def copy_and_replace(self, key, new_value) -> "StreamToken":
+ return attr.evolve(self, **{key: new_value})
+
+
+StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
+
+
+@attr.s(slots=True, frozen=True)
+class PersistedEventPosition:
+ """Position of a newly persisted event with instance that persisted it.
+
+ This can be used to test whether the event is persisted before or after a
+ RoomStreamToken.
+ """
+
+ instance_name = attr.ib(type=str)
+ stream = attr.ib(type=int)
+
+ def persisted_after(self, token: RoomStreamToken) -> bool:
+ return token.stream < self.stream
+
+ def to_room_stream_token(self) -> RoomStreamToken:
+ """Converts the position to a room stream token such that events
+ persisted in the same room after this position will be after the
+ returned `RoomStreamToken`.
+
+ Note: no guarentees are made about ordering w.r.t. events in other
+ rooms.
+ """
+ # Doing the naive thing satisfies the desired properties described in
+ # the docstring.
+ return RoomStreamToken(None, self.stream)
+
+
class ThirdPartyInstanceID(
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
):
@@ -514,7 +599,7 @@ class ThirdPartyInstanceID(
@attr.s(slots=True)
-class ReadReceipt(object):
+class ReadReceipt:
"""Information about a read-receipt"""
room_id = attr.ib()
|