diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index b371574ece..f0e29e9836 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +17,7 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.receipts import ReceiptsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage.receipts import ReceiptsWorkerStore
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@@ -29,36 +28,19 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedReceiptsStore(BaseSlavedStore):
+class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
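+    # Rather than borrowing read methods off DataStore via method descriptors
+    # (see the comment above), we now just inherit them from
+    # ReceiptsWorkerStore.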
def __init__(self, db_conn, hs):
- super(SlavedReceiptsStore, self).__init__(db_conn, hs)
-
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
- self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
- )
-
- get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
- get_linearized_receipts_for_room = (
- ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
- )
- _get_linearized_receipts_for_rooms = (
- ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
- )
- get_last_receipt_event_id_for_user = (
- ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
- )
-
- get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
- get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
+ super(SlavedReceiptsStore, self).__init__(db_conn, hs)
- get_linearized_receipts_for_rooms = (
- DataStore.get_linearized_receipts_for_rooms.__func__
- )
+ def get_max_receipt_stream_id(self):
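+        # The SlavedIdTracker is advanced as replication rows arrive, so its
+        # current token is the highest receipt stream ID this worker has seen.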
+ return self._receipts_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 3f8d4b9c22..83471b3173 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -58,23 +58,13 @@ class MediaStorage(object):
Returns:
Deferred[str]: the file path written to in the primary media store
"""
- path = self._file_info_to_path(file_info)
- fname = os.path.join(self.local_media_directory, path)
- dirname = os.path.dirname(fname)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
-
- # Write to the main repository
- yield make_deferred_yieldable(threads.deferToThread(
- _write_file_synchronously, source, fname,
- ))
-
- # Tell the storage providers about the new file. They'll decide
- # if they should upload it and whether to do so synchronously
- # or not.
- for provider in self.storage_providers:
- yield provider.store_file(path, file_info)
+ with self.store_into_file(file_info) as (f, fname, finish_cb):
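+            # store_into_file yields an open file object, the path that will
+            # be written to, and a callback to run once the write completes.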
+ # Write to the main repository
+ yield make_deferred_yieldable(threads.deferToThread(
+ _write_file_synchronously, source, f,
+ ))
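+            # finish_cb tells the storage providers about the new file; they
+            # decide whether to upload it and whether to do so synchronously.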
+ yield finish_cb()
defer.returnValue(fname)
@@ -240,21 +230,16 @@ class MediaStorage(object):
)
-def _write_file_synchronously(source, fname):
- """Write `source` to the path `fname` synchronously. Should be called
+def _write_file_synchronously(source, dest):
+ """Write `source` to the file like `dest` synchronously. Should be called
from a thread.
Args:
- source: A file like object to be written
- fname (str): Path to write to
+        source: A file-like object to read from
+        dest: A file-like object to write to
"""
- dirname = os.path.dirname(fname)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
-
source.seek(0) # Ensure we read from the start of the file
- with open(fname, "wb") as f:
- shutil.copyfileobj(source, f)
+ shutil.copyfileobj(source, dest)
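+    # Note: ``dest`` is opened and closed by the caller (here via
+    # store_into_file), so this helper only copies the bytes across.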
class FileResponder(Responder):
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f8fbd02ceb..e1c4fe086e 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -104,9 +104,6 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
- self._receipts_id_gen = StreamIdGenerator(
- db_conn, "receipts_linearized", "stream_id"
- )
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 12b3cc7f5f..40530632c6 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +15,13 @@
# limitations under the License.
from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
+import abc
import logging
import ujson as json
@@ -26,39 +29,36 @@ import ujson as json
logger = logging.getLogger(__name__)
-class ReceiptsStore(SQLBaseStore):
+class ReceiptsWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_receipt_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
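+    # (Python 2 spelling; on Python 3 this would be written as
+    # ``class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta)``.)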
+
def __init__(self, db_conn, hs):
- super(ReceiptsStore, self).__init__(db_conn, hs)
+ super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
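+        # The stream change cache is seeded with the current maximum stream
+        # ID, which is why get_max_receipt_stream_id must already be usable.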
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+ "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
+ @abc.abstractmethod
+ def get_max_receipt_stream_id(self):
+ """Get the current max stream ID for receipts stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
defer.returnValue(set(r['user_id'] for r in receipts))
- def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
- user_id):
- if receipt_type != "m.read":
- return
-
- # Returns an ObservableDeferred
- res = self.get_users_with_read_receipts_in_room.cache.get(
- room_id, None, update_metrics=False,
- )
-
- if res:
- if isinstance(res, defer.Deferred) and res.called:
- res = res.result
- if user_id in res:
- # We'd only be adding to the set, so no point invalidating if the
- # user is already there
- return
-
- self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
@@ -270,9 +270,62 @@ class ReceiptsStore(SQLBaseStore):
}
defer.returnValue(results)
+ def get_all_updated_receipts(self, last_id, current_id, limit=None):
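+        # Fetch all receipts with stream IDs in the half-open range
+        # (last_id, current_id]; this feeds the receipts replication stream.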
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_receipts_txn(txn):
+ sql = (
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ )
+ args = [last_id, current_id]
+ if limit is not None:
+ sql += " LIMIT ?"
+ args.append(limit)
+ txn.execute(sql, args)
+
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_receipts", get_all_updated_receipts_txn
+ )
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ def __init__(self, db_conn, hs):
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
+ super(ReceiptsStore, self).__init__(db_conn, hs)
+
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
+ def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+ user_id):
+ if receipt_type != "m.read":
+ return
+
+ # Returns an ObservableDeferred
+ res = self.get_users_with_read_receipts_in_room.cache.get(
+ room_id, None, update_metrics=False,
+ )
+
+ if res:
+ if isinstance(res, defer.Deferred) and res.called:
+ res = res.result
+ if user_id in res:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
+
+ self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
txn.call_after(
@@ -457,25 +510,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
-
- def get_all_updated_receipts(self, last_id, current_id, limit=None):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_receipts_txn(txn):
- sql = (
- "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
- " FROM receipts_linearized"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- )
- args = [last_id, current_id]
- if limit is not None:
- sql += " LIMIT ?"
- args.append(limit)
- txn.execute(sql, args)
-
- return txn.fetchall()
- return self.runInteraction(
- "get_all_updated_receipts", get_all_updated_receipts_txn
- )