diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 663991a9b6..1782428048 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import division
+
import itertools
import logging
from collections import namedtuple
@@ -103,7 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred : A FrozenEvent.
"""
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -142,7 +144,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred : Dict from event_id to event.
"""
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -152,13 +154,32 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks
- def _get_events(
+ def get_events_as_list(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
):
+ """Get events from the database and return in a list in the same order
+ as given by `event_ids` arg.
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred[list[EventBase]]: List of events fetched from the database. The
+ events are in the same order as `event_ids` arg.
+
+ Note that the returned list may be smaller than the list of event
+ IDs if not all events could be fetched.
+ """
+
if not event_ids:
defer.returnValue([])
@@ -202,21 +223,22 @@ class EventsWorkerStore(SQLBaseStore):
#
# The problem is that we end up at this point when an event
# which has been redacted is pulled out of the database by
- # _enqueue_events, because _enqueue_events needs to check the
- # redaction before it can cache the redacted event. So obviously,
- # calling get_event to get the redacted event out of the database
- # gives us an infinite loop.
+ # _enqueue_events, because _enqueue_events needs to check
+ # the redaction before it can cache the redacted event. So
+ # obviously, calling get_event to get the redacted event out
+ # of the database gives us an infinite loop.
#
- # For now (quick hack to fix during 0.99 release cycle), we just
- # go and fetch the relevant row from the db, but it would be nice
- # to think about how we can cache this rather than hit the db
- # every time we access a redaction event.
+ # For now (quick hack to fix during 0.99 release cycle), we
+ # just go and fetch the relevant row from the db, but it
+ # would be nice to think about how we can cache this rather
+ # than hit the db every time we access a redaction event.
#
# One thought on how to do this:
- # 1. split _get_events up so that it is divided into (a) get the
- # rawish event from the db/cache, (b) do the redaction/rejection
- # filtering
- # 2. have _get_event_from_row just call the first half of that
+ # 1. split get_events_as_list up so that it is divided into
+ # (a) get the rawish event from the db/cache, (b) do the
+ # redaction/rejection filtering
+ # 2. have _get_event_from_row just call the first half of
+ # that
orig_sender = yield self._simple_select_one_onecol(
table="events",
@@ -590,4 +612,79 @@ class EventsWorkerStore(SQLBaseStore):
return res
- return self.runInteraction("get_rejection_reasons", f)
+ return self.runInteraction("get_seen_events_with_rejections", f)
+
+ def _get_total_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_total_state_event_counts.
+ """
+ # We join against the events table as that has an index on room_id
+ sql = """
+ SELECT COUNT(*) FROM state_events
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id=?
+ """
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_total_state_event_counts(self, room_id):
+ """
+ Gets the total number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.runInteraction(
+ "get_total_state_event_counts",
+ self._get_total_state_event_counts_txn, room_id
+ )
+
+ def _get_current_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_current_state_event_counts.
+ """
+ sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_current_state_event_counts(self, room_id):
+ """
+ Gets the current number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.runInteraction(
+ "get_current_state_event_counts",
+ self._get_current_state_event_counts_txn, room_id
+ )
+
+ @defer.inlineCallbacks
+ def get_room_complexity(self, room_id):
+ """
+ Get a rough approximation of the complexity of the room. This is used by
+ remote servers to decide whether they wish to join the room or not.
+ Higher complexity value indicates that being in the room will consume
+ more resources.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[dict[str:int]] of complexity version to complexity.
+ """
+ state_events = yield self.get_current_state_event_counts(room_id)
+
+ # Call this one "v1", so we can introduce new ones as we want to develop
+ # it.
+ complexity_v1 = round(state_events / 500, 2)
+
+ defer.returnValue({"v1": complexity_v1})
|