diff --git a/synapse/types.py b/synapse/types.py
index acf60baddc..f7de48f148 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -13,11 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import re
import string
import sys
from collections import namedtuple
-from typing import Any, Dict, Tuple, TypeVar
+from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@@ -29,19 +30,20 @@ from synapse.api.errors import Codes, SynapseError
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
- from typing import Sized, Iterable, Container
+ from typing import Container, Iterable, Sized
T_co = TypeVar("T_co", covariant=True)
- class Collection(Iterable[T_co], Container[T_co], Sized):
+ class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore
__slots__ = ()
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
-StateMap = Dict[Tuple[str, str], T]
-
+StateKey = Tuple[str, str]
+StateMap = Mapping[StateKey, T]
+MutableStateMap = MutableMapping[StateKey, T]
# the type of a JSON-serialisable dict. This could be made stronger, but it will
# do for now.
@@ -50,7 +52,15 @@ JsonDict = Dict[str, Any]
class Requester(
namedtuple(
- "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
+ "Requester",
+ [
+ "user",
+ "access_token_id",
+ "is_guest",
+ "shadow_banned",
+ "device_id",
+ "app_service",
+ ],
)
):
"""
@@ -61,6 +71,7 @@ class Requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
+ shadow_banned (bool): True if the user making this request has been shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
"""
@@ -76,6 +87,7 @@ class Requester(
"user_id": self.user.to_string(),
"access_token_id": self.access_token_id,
"is_guest": self.is_guest,
+ "shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
}
@@ -100,13 +112,19 @@ class Requester(
user=UserID.from_string(input["user_id"]),
access_token_id=input["access_token_id"],
is_guest=input["is_guest"],
+ shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
)
def create_requester(
- user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
+ user_id,
+ access_token_id=None,
+ is_guest=False,
+ shadow_banned=False,
+ device_id=None,
+ app_service=None,
):
"""
Create a new ``Requester`` object
@@ -116,6 +134,7 @@ def create_requester(
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
+ shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
@@ -124,7 +143,9 @@ def create_requester(
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
- return Requester(user_id, access_token_id, is_guest, device_id, app_service)
+ return Requester(
+ user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+ )
def get_domain_from_id(string):
@@ -141,6 +162,9 @@ def get_localpart_from_id(string):
return string[1:idx]
+DS = TypeVar("DS", bound="DomainSpecificString")
+
+
class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -151,6 +175,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
'domain' : The domain part of the name
"""
+ __metaclass__ = abc.ABCMeta
+
+ SIGIL = abc.abstractproperty() # type: str # type: ignore
+
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
@@ -166,7 +194,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
return self
@classmethod
- def from_string(cls, s: str):
+ def from_string(cls: Type[DS], s: str) -> DS:
"""Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(
@@ -190,12 +218,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
# names on one HS
return cls(localpart=parts[0], domain=domain)
- def to_string(self):
+ def to_string(self) -> str:
"""Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
@classmethod
- def is_valid(cls, s):
+ def is_valid(cls: Type[DS], s: str) -> bool:
try:
cls.from_string(s)
return True
@@ -235,8 +263,9 @@ class GroupID(DomainSpecificString):
SIGIL = "+"
@classmethod
- def from_string(cls, s):
- group_id = super(GroupID, cls).from_string(s)
+ def from_string(cls: Type[DS], s: str) -> DS:
+ group_id = super().from_string(s) # type: DS # type: ignore
+
if not group_id.localpart:
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
@@ -500,7 +529,7 @@ class ThirdPartyInstanceID(
@attr.s(slots=True)
-class ReadReceipt(object):
+class ReadReceipt:
"""Information about a read-receipt"""
room_id = attr.ib()
|