diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index c8056b0c0c..444eb7b7f4 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,7 @@
import abc
import logging
import re
+from typing import Dict, List, Tuple
from six import raise_from
from six.moves import urllib
@@ -78,9 +79,8 @@ class ReplicationEndpoint(object):
__metaclass__ = abc.ABCMeta
- NAME = abc.abstractproperty()
- PATH_ARGS = abc.abstractproperty()
-
+ NAME = abc.abstractproperty() # type: str # type: ignore
+ PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
@@ -171,7 +171,7 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
- headers = {}
+ headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = yield request_func(uri, data, headers=headers)
@@ -207,7 +207,7 @@ class ReplicationEndpoint(object):
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler
+ handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index b91a528245..704282c800 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict
+from typing import Dict, Optional
import six
@@ -41,7 +41,7 @@ class BaseSlavedStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
- )
+ ) # type: Optional[SlavedIdTracker]
else:
self._cache_id_gen = None
@@ -62,7 +62,8 @@ class BaseSlavedStore(SQLBaseStore):
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
- self._cache_id_gen.advance(token)
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
room_id = row.keys[0]
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index f552e7c972..ad8f0c15a9 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -29,7 +29,7 @@ class SlavedPresenceStore(BaseSlavedStore):
self._presence_on_startup = self._get_active_presence(db_conn)
- self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
+ self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index bbcb84646c..aa7fd90e26 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,7 +16,7 @@
"""
import logging
-from typing import Dict
+from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
@@ -28,6 +28,7 @@ from synapse.replication.tcp.protocol import (
)
from .commands import (
+ Command,
FederationAckCommand,
InvalidateCacheCommand,
RemovePusherCommand,
@@ -89,15 +90,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# Any pending commands to be sent once a new connection has been
# established
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
- self.awaiting_syncs = {}
+ self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
- self.factory = None
+ self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -235,4 +236,5 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
- self.factory.resetDelay()
+ if self.factory:
+ self.factory.resetDelay()
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0ff2a7199f..cbb36b9acf 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -20,15 +20,16 @@ allowed to be sent by which side.
import logging
import platform
+from typing import Tuple, Type
if platform.python_implementation() == "PyPy":
import json
_json_encoder = json.JSONEncoder()
else:
- import simplejson as json
+ import simplejson as json # type: ignore[no-redef] # noqa: F821
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>`
"""
- NAME = None
+ NAME = None # type: str
def __init__(self, data):
self.data = data
@@ -386,25 +387,24 @@ class UserIpCommand(Command):
)
+_COMMANDS = (
+ ServerCommand,
+ RdataCommand,
+ PositionCommand,
+ ErrorCommand,
+ PingCommand,
+ NameCommand,
+ ReplicateCommand,
+ UserSyncCommand,
+ FederationAckCommand,
+ SyncCommand,
+ RemovePusherCommand,
+ InvalidateCacheCommand,
+ UserIpCommand,
+) # type: Tuple[Type[Command], ...]
+
# Map of command name to command type.
-COMMAND_MAP = {
- cmd.NAME: cmd
- for cmd in (
- ServerCommand,
- RdataCommand,
- PositionCommand,
- ErrorCommand,
- PingCommand,
- NameCommand,
- ReplicateCommand,
- UserSyncCommand,
- FederationAckCommand,
- SyncCommand,
- RemovePusherCommand,
- InvalidateCacheCommand,
- UserIpCommand,
- )
-}
+COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
# The commands the server is allowed to send
VALID_SERVER_COMMANDS = (
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index afaf002fe6..db0353c996 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -53,6 +53,7 @@ import fcntl
import logging
import struct
from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems, iterkeys
@@ -65,13 +66,11 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import Clock
-from synapse.util.stringutils import random_string
-
-from .commands import (
+from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
+ Command,
ErrorCommand,
NameCommand,
PingCommand,
@@ -82,6 +81,10 @@ from .commands import (
SyncCommand,
UserSyncCommand,
)
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
from .streams import STREAMS_MAP
connection_close_counter = Counter(
@@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
delimiter = b"\n"
- VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
- VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
+ # Valid commands we expect to receive
+ VALID_INBOUND_COMMANDS = [] # type: Collection[str]
+
+ # Valid commands we can send
+ VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
@@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = defaultdict(int)
- self.outbound_commands_counter = defaultdict(int)
+ self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
+ self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -409,14 +415,14 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
- self.replication_streams = set()
+ self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
- self.connecting_streams = set()
+ self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
- self.pending_rdata = {}
+ self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
@@ -642,11 +648,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set()
+ self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {}
+ self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -766,7 +772,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
- size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
+ size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size
return 0
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index d1e98428bc..cbfdaf5773 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,6 +17,7 @@
import logging
import random
+from typing import List
from six import itervalues
@@ -79,7 +80,7 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
- self.connections = []
+ self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 8512923eae..4ab0334fc1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,10 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from collections import namedtuple
+from typing import Any
from twisted.internet import defer
@@ -104,8 +104,9 @@ class Stream(object):
time it was called up until the point `advance_current_token` was called.
"""
- NAME = None # The name of the stream
- ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
+ NAME = None # type: str # The name of the stream
+ # The type of the row. Used by the default impl of parse_row.
+ ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
@@ -231,8 +232,8 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_backfill_token
- self.update_function = store.get_all_new_backfill_event_rows
+ self.current_token = store.get_current_backfill_token # type: ignore
+ self.update_function = store.get_all_new_backfill_event_rows # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -246,8 +247,8 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
- self.current_token = store.get_current_presence_token
- self.update_function = presence_handler.get_all_presence_updates
+ self.current_token = store.get_current_presence_token # type: ignore
+ self.update_function = presence_handler.get_all_presence_updates # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -260,8 +261,8 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- self.current_token = typing_handler.get_current_token
- self.update_function = typing_handler.get_all_typing_updates
+ self.current_token = typing_handler.get_current_token # type: ignore
+ self.update_function = typing_handler.get_all_typing_updates # type: ignore
super(TypingStream, self).__init__(hs)
@@ -273,8 +274,8 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_receipt_stream_id
- self.update_function = store.get_all_updated_receipts
+ self.current_token = store.get_max_receipt_stream_id # type: ignore
+ self.update_function = store.get_all_updated_receipts # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -310,8 +311,8 @@ class PushersStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_pushers_stream_token
- self.update_function = store.get_all_updated_pushers_rows
+ self.current_token = store.get_pushers_stream_token # type: ignore
+ self.update_function = store.get_all_updated_pushers_rows # type: ignore
super(PushersStream, self).__init__(hs)
@@ -327,8 +328,8 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_cache_stream_token
- self.update_function = store.get_all_updated_caches
+ self.current_token = store.get_cache_stream_token # type: ignore
+ self.update_function = store.get_all_updated_caches # type: ignore
super(CachesStream, self).__init__(hs)
@@ -343,8 +344,8 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_public_room_stream_id
- self.update_function = store.get_all_new_public_rooms
+ self.current_token = store.get_current_public_room_stream_id # type: ignore
+ self.update_function = store.get_all_new_public_rooms # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -360,8 +361,8 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_device_stream_token
- self.update_function = store.get_all_device_list_changes_for_remotes
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -376,8 +377,8 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_to_device_stream_token
- self.update_function = store.get_all_new_device_messages
+ self.current_token = store.get_to_device_stream_token # type: ignore
+ self.update_function = store.get_all_new_device_messages # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -392,8 +393,8 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_account_data_stream_id
- self.update_function = store.get_all_updated_tags
+ self.current_token = store.get_max_account_data_stream_id # type: ignore
+ self.update_function = store.get_all_updated_tags # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -408,7 +409,7 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
- self.current_token = self.store.get_max_account_data_stream_id
+ self.current_token = self.store.get_max_account_data_stream_id # type: ignore
super(AccountDataStream, self).__init__(hs)
@@ -434,8 +435,8 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_group_stream_token
- self.update_function = store.get_all_groups_changes
+ self.current_token = store.get_group_stream_token # type: ignore
+ self.update_function = store.get_all_groups_changes # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -451,7 +452,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_device_stream_token
- self.update_function = store.get_all_user_signature_changes_for_remotes
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index d97669c886..0843e5aa90 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,7 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import heapq
+from typing import Tuple, Type
import attr
@@ -63,7 +65,8 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = None # type: str
@classmethod
def from_data(cls, data):
@@ -99,9 +102,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id = attr.ib() # str, optional
-TypeToRow = {
- Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
-}
+_EventRows = (
+ EventsStreamEventRow,
+ EventsStreamCurrentStateRow,
+) # type: Tuple[Type[BaseEventsStreamRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream):
@@ -112,7 +118,7 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
- self.current_token = self._store.get_current_events_token
+ self.current_token = self._store.get_current_events_token # type: ignore
super(EventsStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index dc2484109d..615f3dc9ac 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -37,7 +37,7 @@ class FederationStream(Stream):
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
- self.current_token = federation_sender.get_current_token
- self.update_function = federation_sender.get_replication_rows
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = federation_sender.get_replication_rows # type: ignore
super(FederationStream, self).__init__(hs)
|