diff --git a/CHANGES.rst b/CHANGES.rst
index f1529e79bd..80518b7bae 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,45 @@
+Changes in synapse v0.24.0 (2017-10-23)
+=======================================
+
+No changes since v0.24.0-rc1
+
+
+Changes in synapse v0.24.0-rc1 (2017-10-19)
+===========================================
+
+Features:
+
+* Add Group Server (PR #2352, #2363, #2374, #2377, #2378, #2382, #2410, #2426,
+ #2430, #2454, #2471, #2472, #2544)
+* Add support for channel notifications (PR #2501)
+* Add basic implementation of backup media store (PR #2538)
+* Add config option to auto-join new users to rooms (PR #2545)
+
+
+Changes:
+
+* Make the spam checker a module (PR #2474)
+* Delete expired url cache data (PR #2478)
+* Ignore incoming events for rooms that we have left (PR #2490)
+* Allow spam checker to reject invites too (PR #2492)
+* Add room creation checks to spam checker (PR #2495)
+* Spam checking: add the invitee to user_may_invite (PR #2502)
+* Process events from federation for different rooms in parallel (PR #2520)
+* Allow error strings from spam checker (PR #2531)
+* Improve error handling for missing files in config (PR #2551)
+
+
+Bug fixes:
+
+* Fix handling SERVFAILs when doing AAAA lookups for federation (PR #2477)
+* Fix incompatibility with newer versions of ujson (PR #2483) Thanks to
+ @jeremycline!
+* Fix notification keywords that start/end with non-word chars (PR #2500)
+* Fix stack overflow and logcontexts from linearizer (PR #2532)
+* Fix 500 error when fields missing from power_levels event (PR #2552)
+* Fix 500 error when we get an error handling a PDU (PR #2553)
+
+
Changes in synapse v0.23.1 (2017-10-02)
=======================================
diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py
index 15c19834fc..0b15bd8912 100644
--- a/docs/sphinx/conf.py
+++ b/docs/sphinx/conf.py
@@ -50,7 +50,7 @@ master_doc = 'index'
# General information about the project.
project = u'Synapse'
-copyright = u'2014, TNG'
+copyright = u'Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index bc167b59af..dc7fe940e8 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -376,10 +376,13 @@ class Porter(object):
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
)
- rows_dict = [
- dict(zip(headers, row))
- for row in rows
- ]
+ rows_dict = []
+ for row in rows:
+ d = dict(zip(headers, row))
+ if "\0" in d['value']:
+ logger.warn('dropping search row %s', d)
+ else:
+ rows_dict.append(d)
txn.executemany(sql, [
(
diff --git a/synapse/__init__.py b/synapse/__init__.py
index bee4aba625..c867d1cfd8 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.23.1"
+__version__ = "0.24.0"
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 80e4ba5336..576ac6fb7e 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -40,6 +40,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -69,6 +70,7 @@ class SynchrotronSlavedStore(
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
+ SlavedGroupServerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedClientIpStore,
@@ -403,6 +405,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
)
elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows)
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "groups_key", token, users=[row.user_id for row in rows],
+ )
def start(config_options):
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1ab5593c6e..fa105bce72 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -82,21 +82,37 @@ class Config(object):
return os.path.abspath(file_path) if file_path else file_path
@classmethod
+ def path_exists(cls, file_path):
+ """Check if a file exists
+
+ Unlike os.path.exists, this throws an exception if there is an error
+ checking if the file exists (for example, if there is a perms error on
+ the parent dir).
+
+ Returns:
+ bool: True if the file exists; False if not.
+ """
+ try:
+ os.stat(file_path)
+ return True
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise e
+ return False
+
+ @classmethod
def check_file(cls, file_path, config_name):
if file_path is None:
raise ConfigError(
"Missing config for %s."
- " You must specify a path for the config file. You can "
- "do this with the -c or --config-path option. "
- "Adding --generate-config along with --server-name "
- "<server name> will generate a config file at the given path."
% (config_name,)
)
- if not os.path.exists(file_path):
+ try:
+ os.stat(file_path)
+ except OSError as e:
raise ConfigError(
- "File %s config for %s doesn't exist."
- " Try running again with --generate-config"
- % (file_path, config_name,)
+ "Error accessing file '%s' (config for %s): %s"
+ % (file_path, config_name, e.strerror)
)
return cls.abspath(file_path)
@@ -248,7 +264,7 @@ class Config(object):
" -c CONFIG-FILE\""
)
(config_path,) = config_files
- if not os.path.exists(config_path):
+ if not cls.path_exists(config_path):
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
@@ -261,7 +277,7 @@ class Config(object):
"Must specify a server_name to a generate config for."
" Pass -H server.name."
)
- if not os.path.exists(config_dir_path):
+ if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config(
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
new file mode 100644
index 0000000000..997fa2881f
--- /dev/null
+++ b/synapse/config/groups.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class GroupsConfig(Config):
+ def read_config(self, config):
+ self.enable_group_creation = config.get("enable_group_creation", False)
+ self.group_creation_prefix = config.get("group_creation_prefix", "")
+
+ def default_config(self, **kwargs):
+ return """\
+ # Whether to allow non server admins to create groups on this server
+ enable_group_creation: false
+
+ # If enabled, non server admins can only create groups with local parts
+ # starting with this prefix
+ # group_creation_prefix: "unofficial/"
+ """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index b22cacf8dc..05e242aef6 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -34,6 +34,8 @@ from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig
from .workers import WorkerConfig
from .push import PushConfig
+from .spam_checker import SpamCheckerConfig
+from .groups import GroupsConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -41,7 +43,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
- WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
+ WorkerConfig, PasswordAuthProviderConfig, PushConfig,
+ SpamCheckerConfig, GroupsConfig,):
pass
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 6ee643793e..4b8fc063d0 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -118,10 +118,9 @@ class KeyConfig(Config):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return read_signing_keys(signing_keys.splitlines(True))
- except Exception:
+ except Exception as e:
raise ConfigError(
- "Error reading signing_key."
- " Try running again with --generate-config"
+ "Error reading signing_key: %s" % (str(e))
)
def read_old_signing_keys(self, old_signing_keys):
@@ -141,7 +140,8 @@ class KeyConfig(Config):
def generate_files(self, config):
signing_key_path = config["signing_key_path"]
- if not os.path.exists(signing_key_path):
+
+ if not self.path_exists(signing_key_path):
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
write_signing_keys(
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 83762d089a..90824cab7f 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -15,13 +15,15 @@
from ._base import Config, ConfigError
-import importlib
+from synapse.util.module_loader import load_module
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
self.password_providers = []
+ provider_config = None
+
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
@@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config):
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
from ldap_auth_provider import LdapAuthProvider
provider_class = LdapAuthProvider
+ try:
+ provider_config = provider_class.parse_config(provider["config"])
+ except Exception as e:
+ raise ConfigError(
+ "Failed to parse config for %r: %r" % (provider['module'], e)
+ )
else:
- # We need to import the module, and then pick the class out of
- # that, so we split based on the last dot.
- module, clz = provider['module'].rsplit(".", 1)
- module = importlib.import_module(module)
- provider_class = getattr(module, clz)
+ (provider_class, provider_config) = load_module(provider)
- try:
- provider_config = provider_class.parse_config(provider["config"])
- except Exception as e:
- raise ConfigError(
- "Failed to parse config for %r: %r" % (provider['module'], e)
- )
self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs):
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index f7e03c4cde..ef917fc9f2 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -41,6 +41,8 @@ class RegistrationConfig(Config):
self.allow_guest_access and config.get("invite_3pid_guest", False)
)
+ self.auto_join_rooms = config.get("auto_join_rooms", [])
+
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
@@ -70,6 +72,11 @@ class RegistrationConfig(Config):
- matrix.org
- vector.im
- riot.im
+
+ # Users who register on this homeserver will automatically be joined
+ # to these rooms
+ #auto_join_rooms:
+ # - "#example:example.com"
""" % locals()
def add_arguments(self, parser):
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 2c6f57168e..6baa474931 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -70,7 +70,19 @@ class ContentRepositoryConfig(Config):
self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.max_spider_size = self.parse_size(config["max_spider_size"])
+
self.media_store_path = self.ensure_directory(config["media_store_path"])
+
+ self.backup_media_store_path = config.get("backup_media_store_path")
+ if self.backup_media_store_path:
+ self.backup_media_store_path = self.ensure_directory(
+ self.backup_media_store_path
+ )
+
+ self.synchronous_backup_media_store = config.get(
+ "synchronous_backup_media_store", False
+ )
+
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
@@ -115,6 +127,14 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
+ # A secondary directory where uploaded images and attachments are
+ # stored as a backup.
+ # backup_media_store_path: "%(media_store)s"
+
+ # Whether to wait for successful write to backup media store before
+ # returning successfully.
+ # synchronous_backup_media_store: false
+
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
new file mode 100644
index 0000000000..3fec42bdb0
--- /dev/null
+++ b/synapse/config/spam_checker.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.module_loader import load_module
+
+from ._base import Config
+
+
+class SpamCheckerConfig(Config):
+ def read_config(self, config):
+ self.spam_checker = None
+
+ provider = config.get("spam_checker", None)
+ if provider is not None:
+ self.spam_checker = load_module(provider)
+
+ def default_config(self, **kwargs):
+ return """\
+ # spam_checker:
+ # module: "my_custom_project.SuperSpamChecker"
+ # config:
+ # example_option: 'things'
+ """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index e081840a83..247f18f454 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -126,7 +126,7 @@ class TlsConfig(Config):
tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"]
- if not os.path.exists(tls_private_key_path):
+ if not self.path_exists(tls_private_key_path):
with open(tls_private_key_path, "w") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
@@ -141,7 +141,7 @@ class TlsConfig(Config):
crypto.FILETYPE_PEM, private_key_pem
)
- if not os.path.exists(tls_certificate_path):
+ if not self.path_exists(tls_certificate_path):
with open(tls_certificate_path, "w") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
@@ -159,7 +159,7 @@ class TlsConfig(Config):
certificate_file.write(cert_pem)
- if not os.path.exists(tls_dh_params_path):
+ if not self.path_exists(tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4096c606f1..9e746a28bf 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events):
("invite", None),
]
- old_list = current_state.content.get("users")
+ old_list = current_state.content.get("users", {})
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
- old_list = current_state.content.get("events")
- new_list = event.content.get("events")
+ old_list = current_state.content.get("events", {})
+ new_list = event.content.get("events", {})
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 56fa9e556e..dccc579eac 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -14,25 +14,100 @@
# limitations under the License.
-def check_event_for_spam(event):
- """Checks if a given event is considered "spammy" by this server.
+class SpamChecker(object):
+ def __init__(self, hs):
+ self.spam_checker = None
- If the server considers an event spammy, then it will be rejected if
- sent by a local user. If it is sent by a user on another server, then
- users receive a blank event.
+ module = None
+ config = None
+ try:
+ module, config = hs.config.spam_checker
+ except:
+ pass
- Args:
- event (synapse.events.EventBase): the event to be checked
+ if module is not None:
+ self.spam_checker = module(config=config)
- Returns:
- bool: True if the event is spammy.
- """
- if not hasattr(event, "content") or "body" not in event.content:
- return False
+ def check_event_for_spam(self, event):
+ """Checks if a given event is considered "spammy" by this server.
- # for example:
- #
- # if "the third flower is green" in event.content["body"]:
- # return True
+ If the server considers an event spammy, then it will be rejected if
+ sent by a local user. If it is sent by a user on another server, then
+ users receive a blank event.
- return False
+ Args:
+ event (synapse.events.EventBase): the event to be checked
+
+ Returns:
+ bool: True if the event is spammy.
+ """
+ if self.spam_checker is None:
+ return False
+
+ return self.spam_checker.check_event_for_spam(event)
+
+ def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ """Checks if a given user may send an invite
+
+ If this method returns false, the invite will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+
+ Returns:
+ bool: True if the user may send an invite, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+
+ def user_may_create_room(self, userid):
+ """Checks if a given user may create a room
+
+ If this method returns false, the creation request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+
+ Returns:
+ bool: True if the user may create a room, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_create_room(userid)
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Checks if a given user may create a room alias
+
+ If this method returns false, the association request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+ room_alias (string): The alias to be created
+
+ Returns:
+ bool: True if the user may create a room alias, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_create_room_alias(userid, room_alias)
+
+ def user_may_publish_room(self, userid, room_id):
+ """Checks if a given user may publish a room to the directory
+
+ If this method returns false, the publish request will be rejected.
+
+ Args:
+ userid (string): The sender's user ID
+ room_id (string): The ID of the room that would be published
+
+ Returns:
+ bool: True if the user may publish the room, otherwise False
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_publish_room(userid, room_id)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index babd9ea078..a0f5d40eb3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -16,7 +16,6 @@ import logging
from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash
-from synapse.events import spamcheck
from synapse.events.utils import prune_event
from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer
@@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
def __init__(self, hs):
- pass
+ self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -144,7 +143,7 @@ class FederationBase(object):
)
return redacted
- if spamcheck.check_event_for_spam(pdu):
+ if self.spam_checker.check_event_for_spam(pdu):
logger.warn(
"Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 51e3fdea06..e15228e70b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -12,14 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
from twisted.internet import defer
from .federation_base import FederationBase
from .units import Transaction, Edu
-from synapse.util.async import Linearizer
+from synapse.util import async
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
@@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
import simplejson as json
import logging
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
logger = logging.getLogger(__name__)
@@ -52,7 +53,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
- self._server_linearizer = Linearizer("fed_server")
+ self._server_linearizer = async.Linearizer("fed_server")
+ self._transaction_linearizer = async.Linearizer("fed_txn_handler")
# We cache responses to state queries, as they take a while and often
# come in waves.
@@ -109,25 +111,41 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
+ # keep this as early as possible to make the calculated origin ts as
+ # accurate as possible.
+ request_time = self._clock.time_msec()
+
transaction = Transaction(**transaction_data)
- received_pdus_counter.inc_by(len(transaction.pdus))
+ if not transaction.transaction_id:
+ raise Exception("Transaction missing transaction_id")
+ if not transaction.origin:
+ raise Exception("Transaction missing origin")
- for p in transaction.pdus:
- if "unsigned" in p:
- unsigned = p["unsigned"]
- if "age" in unsigned:
- p["age"] = unsigned["age"]
- if "age" in p:
- p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
- del p["age"]
+ logger.debug("[%s] Got transaction", transaction.transaction_id)
- pdu_list = [
- self.event_from_pdu_json(p) for p in transaction.pdus
- ]
+ # use a linearizer to ensure that we don't process the same transaction
+ # multiple times in parallel.
+ with (yield self._transaction_linearizer.queue(
+ (transaction.origin, transaction.transaction_id),
+ )):
+ result = yield self._handle_incoming_transaction(
+ transaction, request_time,
+ )
- logger.debug("[%s] Got transaction", transaction.transaction_id)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _handle_incoming_transaction(self, transaction, request_time):
+ """ Process an incoming transaction and return the HTTP response
+
+ Args:
+ transaction (Transaction): incoming transaction
+ request_time (int): timestamp that the HTTP request arrived at
+ Returns:
+ Deferred[(int, object)]: http response code and body
+ """
response = yield self.transaction_actions.have_responded(transaction)
if response:
@@ -140,42 +158,49 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- results = []
-
- for pdu in pdu_list:
- # check that it's actually being sent from a valid destination to
- # workaround bug #1753 in 0.18.5 and 0.18.6
- if transaction.origin != get_domain_from_id(pdu.event_id):
- # We continue to accept join events from any server; this is
- # necessary for the federation join dance to work correctly.
- # (When we join over federation, the "helper" server is
- # responsible for sending out the join event, rather than the
- # origin. See bug #1893).
- if not (
- pdu.type == 'm.room.member' and
- pdu.content and
- pdu.content.get("membership", None) == 'join'
- ):
- logger.info(
- "Discarding PDU %s from invalid origin %s",
- pdu.event_id, transaction.origin
- )
- continue
- else:
- logger.info(
- "Accepting join PDU %s from %s",
- pdu.event_id, transaction.origin
- )
+ received_pdus_counter.inc_by(len(transaction.pdus))
- try:
- yield self._handle_received_pdu(transaction.origin, pdu)
- results.append({})
- except FederationError as e:
- self.send_failure(e, transaction.origin)
- results.append({"error": str(e)})
- except Exception as e:
- results.append({"error": str(e)})
- logger.exception("Failed to handle PDU")
+ pdus_by_room = {}
+
+ for p in transaction.pdus:
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
+ if "age" in p:
+ p["age_ts"] = request_time - int(p["age"])
+ del p["age"]
+
+ event = self.event_from_pdu_json(p)
+ room_id = event.room_id
+ pdus_by_room.setdefault(room_id, []).append(event)
+
+ pdu_results = {}
+
+ # we can process different rooms in parallel (which is useful if they
+ # require callouts to other servers to fetch missing events), but
+ # impose a limit to avoid going too crazy with ram/cpu.
+ @defer.inlineCallbacks
+ def process_pdus_for_room(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ try:
+ yield self._handle_received_pdu(
+ transaction.origin, pdu
+ )
+ pdu_results[event_id] = {}
+ except FederationError as e:
+ logger.warn("Error handling PDU %s: %s", event_id, e)
+ pdu_results[event_id] = {"error": str(e)}
+ except Exception as e:
+ pdu_results[event_id] = {"error": str(e)}
+ logger.exception("Failed to handle PDU %s", event_id)
+
+ yield async.concurrently_execute(
+ process_pdus_for_room, pdus_by_room.keys(),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
@@ -185,17 +210,16 @@ class FederationServer(FederationBase):
edu.content
)
- for failure in getattr(transaction, "pdu_failures", []):
- logger.info("Got failure %r", failure)
-
- logger.debug("Returning: %s", str(results))
+ pdu_failures = getattr(transaction, "pdu_failures", [])
+ for failure in pdu_failures:
+ logger.info("Got failure %r", failure)
response = {
- "pdus": dict(zip(
- (p.event_id for p in pdu_list), results
- )),
+ "pdus": pdu_results,
}
+ logger.debug("Returning: %s", str(response))
+
yield self.transaction_actions.set_response(
transaction,
200, response
@@ -520,6 +544,30 @@ class FederationServer(FederationBase):
Returns (Deferred): completes with None
Raises: FederationError if the signatures / hash do not match
"""
+ # check that it's actually being sent from a valid destination to
+ # workaround bug #1753 in 0.18.5 and 0.18.6
+ if origin != get_domain_from_id(pdu.event_id):
+ # We continue to accept join events from any server; this is
+ # necessary for the federation join dance to work correctly.
+ # (When we join over federation, the "helper" server is
+ # responsible for sending out the join event, rather than the
+ # origin. See bug #1893).
+ if not (
+ pdu.type == 'm.room.member' and
+ pdu.content and
+ pdu.content.get("membership", None) == 'join'
+ ):
+ logger.info(
+ "Discarding PDU %s from invalid origin %s",
+ pdu.event_id, origin
+ )
+ return
+ else:
+ logger.info(
+ "Accepting join PDU %s from %s",
+ pdu.event_id, origin
+ )
+
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 003eaba893..7a3c9cbb70 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -20,8 +20,8 @@ from .persistence import TransactionActions
from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException
+from synapse.util import logcontext
from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
@@ -231,11 +231,9 @@ class TransactionQueue(object):
(pdu, order)
)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
- @preserve_fn # the caller should not yield on this
+ @logcontext.preserve_fn # the caller should not yield on this
@defer.inlineCallbacks
def send_presence(self, states):
"""Send the new presence states to the appropriate destinations.
@@ -299,7 +297,7 @@ class TransactionQueue(object):
state.user_id: state for state in states
})
- preserve_fn(self._attempt_new_transaction)(destination)
+ self._attempt_new_transaction(destination)
def send_edu(self, destination, edu_type, content, key=None):
edu = Edu(
@@ -321,9 +319,7 @@ class TransactionQueue(object):
else:
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
@@ -336,9 +332,7 @@ class TransactionQueue(object):
destination, []
).append(failure)
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
@@ -347,15 +341,24 @@ class TransactionQueue(object):
if not self.can_send_to(destination):
return
- preserve_context_over_fn(
- self._attempt_new_transaction, destination
- )
+ self._attempt_new_transaction(destination)
def get_current_token(self):
return 0
- @defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
+ """Try to start a new transaction to this destination
+
+ If there is already a transaction in progress to this destination,
+ returns immediately. Otherwise kicks off the process of sending a
+ transaction in the background.
+
+ Args:
+ destination (str):
+
+ Returns:
+ None
+ """
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
@@ -368,6 +371,19 @@ class TransactionQueue(object):
)
return
+ logger.debug("TX [%s] Starting transaction loop", destination)
+
+ # Drop the logcontext before starting the transaction. It doesn't
+ # really make sense to log all the outbound transactions against
+ # whatever path led us to this point: that's pretty arbitrary really.
+ #
+ # (this also means we can fire off _perform_transaction without
+ # yielding)
+ with logcontext.PreserveLoggingContext():
+ self._transaction_transmission_loop(destination)
+
+ @defer.inlineCallbacks
+ def _transaction_transmission_loop(self, destination):
pending_pdus = []
try:
self.pending_transactions[destination] = 1
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 52b2a717d2..125d8f3598 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -471,3 +471,384 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
+
+ @log_function
+ def get_group_profile(self, destination, group_id, requester_user_id):
+ """Get a group profile
+ """
+ path = PREFIX + "/groups/%s/profile" % (group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_summary(self, destination, group_id, requester_user_id):
+ """Get a group summary
+ """
+ path = PREFIX + "/groups/%s/summary" % (group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_rooms_in_group(self, destination, group_id, requester_user_id):
+ """Get all rooms in a group
+ """
+ path = PREFIX + "/groups/%s/rooms" % (group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
+ content):
+ """Add a room to a group
+ """
+ path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+ """Remove a room from a group
+ """
+ path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_users_in_group(self, destination, group_id, requester_user_id):
+ """Get users in a group
+ """
+ path = PREFIX + "/groups/%s/users" % (group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+ """Get users that have been invited to a group
+ """
+ path = PREFIX + "/groups/%s/invited_users" % (group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def accept_group_invite(self, destination, group_id, user_id, content):
+ """Accept a group invite
+ """
+ path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+ """Invite a user to a group
+ """
+ path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def invite_to_group_notification(self, destination, group_id, user_id, content):
+ """Sent by group server to inform a user's server that they have been
+ invited.
+ """
+
+ path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def remove_user_from_group(self, destination, group_id, requester_user_id,
+ user_id, content):
+ """Remove a user fron a group
+ """
+ path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def remove_user_from_group_notification(self, destination, group_id, user_id,
+ content):
+ """Sent by group server to inform a user's server that they have been
+ kicked from the group.
+ """
+
+ path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def renew_group_attestation(self, destination, group_id, user_id, content):
+ """Sent by either a group server or a user's server to periodically update
+ the attestations
+ """
+
+ path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_summary_room(self, destination, group_id, user_id, room_id,
+ category_id, content):
+ """Update a room entry in a group summary
+ """
+ if category_id:
+ path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
+ group_id, category_id, room_id,
+ )
+ else:
+ path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_summary_room(self, destination, group_id, user_id, room_id,
+ category_id):
+ """Delete a room entry in a group summary
+ """
+ if category_id:
+ path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
+ group_id, category_id, room_id,
+ )
+ else:
+ path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_categories(self, destination, group_id, requester_user_id):
+ """Get all categories in a group
+ """
+ path = PREFIX + "/groups/%s/categories" % (group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_category(self, destination, group_id, requester_user_id, category_id):
+ """Get category info in a group
+ """
+ path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_category(self, destination, group_id, requester_user_id, category_id,
+ content):
+ """Update a category in a group
+ """
+ path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_category(self, destination, group_id, requester_user_id,
+ category_id):
+ """Delete a category in a group
+ """
+ path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_roles(self, destination, group_id, requester_user_id):
+ """Get all roles in a group
+ """
+ path = PREFIX + "/groups/%s/roles" % (group_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def get_group_role(self, destination, group_id, requester_user_id, role_id):
+ """Get a roles info
+ """
+ path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
+
+ return self.client.get_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_role(self, destination, group_id, requester_user_id, role_id,
+ content):
+ """Update a role in a group
+ """
+ path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+ """Delete a role in a group
+ """
+ path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def update_group_summary_user(self, destination, group_id, requester_user_id,
+ user_id, role_id, content):
+ """Update a users entry in a group
+ """
+ if role_id:
+ path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
+ group_id, role_id, user_id,
+ )
+ else:
+ path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def delete_group_summary_user(self, destination, group_id, requester_user_id,
+ user_id, role_id):
+ """Delete a users entry in a group
+ """
+ if role_id:
+ path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
+ group_id, role_id, user_id,
+ )
+ else:
+ path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
+
+ return self.client.delete_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ ignore_backoff=True,
+ )
+
+ def bulk_get_publicised_groups(self, destination, user_ids):
+ """Get the groups a list of users are publicising
+ """
+
+ path = PREFIX + "/get_groups_publicised"
+
+ content = {"user_ids": user_ids}
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a78f01e442..f0778c65c5 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -25,7 +25,7 @@ from synapse.http.servlet import (
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
from synapse.util.logcontext import preserve_fn
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
import functools
import logging
@@ -609,6 +609,493 @@ class FederationVersionServlet(BaseFederationServlet):
}))
+class FederationGroupsProfileServlet(BaseFederationServlet):
+ """Get the basic profile of a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/profile$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_group_profile(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryServlet(BaseFederationServlet):
+ PATH = "/groups/(?P<group_id>[^/]*)/summary$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_group_summary(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.update_group_profile(
+ group_id, requester_user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRoomsServlet(BaseFederationServlet):
+ """Get the rooms in a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_rooms_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+ """Add/remove room from group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.add_room_to_group(
+ group_id, requester_user_id, room_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.remove_room_from_group(
+ group_id, requester_user_id, room_id,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsUsersServlet(BaseFederationServlet):
+ """Get the users in a group on behalf of a user
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+ """Get the users that have been invited to a group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.get_invited_users_in_group(
+ group_id, requester_user_id
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsInviteServlet(BaseFederationServlet):
+ """Ask a group server to invite someone to the group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.invite_to_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+ """Accept an invitation from the group server
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(user_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.accept_invite(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+ """Leave or kick a user from the group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+ """A group server has invited a local user
+ """
+ PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(group_id) != origin:
+ raise SynapseError(403, "group_id doesn't match origin")
+
+ new_content = yield self.handler.on_invite(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+ """A group server has removed a local user
+ """
+ PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(group_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.user_removed_from_group(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
+ """A group or user's server renews their attestation
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ # We don't need to check auth here as we check the attestation signatures
+
+ new_content = yield self.handler.on_renew_attestation(
+ group_id, user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+ """Add/remove a room from the group summary, with optional category.
+
+ Matches both:
+ - /groups/:group/summary/rooms/:room_id
+ - /groups/:group/summary/categories/:category/rooms/:room_id
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/categories/(?P<category_id>[^/]+))?"
+ "/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, category_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.update_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoriesServlet(BaseFederationServlet):
+ """Get all categories for a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/categories/$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_categories(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoryServlet(BaseFederationServlet):
+ """Add/remove/get a category in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_category(
+ group_id, requester_user_id, category_id
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.upsert_group_category(
+ group_id, requester_user_id, category_id, content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, category_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_category(
+ group_id, requester_user_id, category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsRolesServlet(BaseFederationServlet):
+ """Get roles in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/roles/$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_roles(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsRoleServlet(BaseFederationServlet):
+ """Add/remove/get a role in a group
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ resp = yield self.handler.get_group_role(
+ group_id, requester_user_id, role_id
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.update_group_role(
+ group_id, requester_user_id, role_id, content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, role_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_role(
+ group_id, requester_user_id, role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+ """Add/remove a user from the group summary, with optional role.
+
+ Matches both:
+ - /groups/:group/summary/users/:user_id
+ - /groups/:group/summary/roles/:role/users/:user_id
+ """
+ PATH = (
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P<user_id>[^/]*)$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, role_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.update_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty string")
+
+ resp = yield self.handler.delete_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+ """Get roles in a group
+ """
+ PATH = (
+ "/get_groups_publicised$"
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query):
+ resp = yield self.handler.bulk_get_publicised_groups(
+ content["user_ids"], proxy=False,
+ )
+
+ defer.returnValue((200, resp))
+
+
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
@@ -635,10 +1122,40 @@ FEDERATION_SERVLET_CLASSES = (
FederationVersionServlet,
)
+
ROOM_LIST_CLASSES = (
PublicRoomList,
)
+GROUP_SERVER_SERVLET_CLASSES = (
+ FederationGroupsProfileServlet,
+ FederationGroupsSummaryServlet,
+ FederationGroupsRoomsServlet,
+ FederationGroupsUsersServlet,
+ FederationGroupsInvitedUsersServlet,
+ FederationGroupsInviteServlet,
+ FederationGroupsAcceptInviteServlet,
+ FederationGroupsRemoveUserServlet,
+ FederationGroupsSummaryRoomsServlet,
+ FederationGroupsCategoriesServlet,
+ FederationGroupsCategoryServlet,
+ FederationGroupsRolesServlet,
+ FederationGroupsRoleServlet,
+ FederationGroupsSummaryUsersServlet,
+)
+
+
+GROUP_LOCAL_SERVLET_CLASSES = (
+ FederationGroupsLocalInviteServlet,
+ FederationGroupsRemoveLocalUserServlet,
+ FederationGroupsBulkPublicisedServlet,
+)
+
+
+GROUP_ATTESTATION_SERVLET_CLASSES = (
+ FederationGroupsRenewAttestaionServlet,
+)
+
def register_servlets(hs, resource, authenticator, ratelimiter):
for servletclass in FEDERATION_SERVLET_CLASSES:
@@ -656,3 +1173,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
ratelimiter=ratelimiter,
server_name=hs.hostname,
).register(resource)
+
+ for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_server_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_local_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_attestation_renewer(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/groups/__init__.py
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
new file mode 100644
index 0000000000..b751cf5e43
--- /dev/null
+++ b/synapse/groups/attestations.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn
+
+from signedjson.sign import sign_json
+
+
+# Default validity duration for new attestations we create
+DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
+
+# Start trying to update our attestations when they come this close to expiring
+UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
+
+
+class GroupAttestationSigning(object):
+ """Creates and verifies group attestations.
+ """
+ def __init__(self, hs):
+ self.keyring = hs.get_keyring()
+ self.clock = hs.get_clock()
+ self.server_name = hs.hostname
+ self.signing_key = hs.config.signing_key[0]
+
+ @defer.inlineCallbacks
+ def verify_attestation(self, attestation, group_id, user_id, server_name=None):
+ """Verifies that the given attestation matches the given parameters.
+
+ An optional server_name can be supplied to explicitly set which server's
+ signature is expected. Otherwise assumes that either the group_id or user_id
+ is local and uses the other's server as the one to check.
+ """
+
+ if not server_name:
+ if get_domain_from_id(group_id) == self.server_name:
+ server_name = get_domain_from_id(user_id)
+ elif get_domain_from_id(user_id) == self.server_name:
+ server_name = get_domain_from_id(group_id)
+ else:
+ raise Exception("Expected either group_id or user_id to be local")
+
+ if user_id != attestation["user_id"]:
+ raise SynapseError(400, "Attestation has incorrect user_id")
+
+ if group_id != attestation["group_id"]:
+ raise SynapseError(400, "Attestation has incorrect group_id")
+ valid_until_ms = attestation["valid_until_ms"]
+
+ # TODO: We also want to check that *new* attestations that people give
+ # us to store are valid for at least a little while.
+ if valid_until_ms < self.clock.time_msec():
+ raise SynapseError(400, "Attestation expired")
+
+ yield self.keyring.verify_json_for_server(server_name, attestation)
+
+ def create_attestation(self, group_id, user_id):
+ """Create an attestation for the group_id and user_id with default
+ validity length.
+ """
+ return sign_json({
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
+ }, self.server_name, self.signing_key)
+
+
+class GroupAttestionRenewer(object):
+ """Responsible for sending and receiving attestation updates.
+ """
+
+ def __init__(self, hs):
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.assestations = hs.get_groups_attestation_signing()
+ self.transport_client = hs.get_federation_transport_client()
+ self.is_mine_id = hs.is_mine_id
+ self.attestations = hs.get_groups_attestation_signing()
+
+ self._renew_attestations_loop = self.clock.looping_call(
+ self._renew_attestations, 30 * 60 * 1000,
+ )
+
+ @defer.inlineCallbacks
+ def on_renew_attestation(self, group_id, user_id, content):
+ """When a remote updates an attestation
+ """
+ attestation = content["attestation"]
+
+ if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
+ raise SynapseError(400, "Neither user not group are on this server")
+
+ yield self.attestations.verify_attestation(
+ attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+
+ yield self.store.update_remote_attestion(group_id, user_id, attestation)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def _renew_attestations(self):
+ """Called periodically to check if we need to update any of our attestations
+ """
+
+ now = self.clock.time_msec()
+
+ rows = yield self.store.get_attestations_need_renewals(
+ now + UPDATE_ATTESTATION_TIME_MS
+ )
+
+ @defer.inlineCallbacks
+ def _renew_attestation(group_id, user_id):
+ attestation = self.attestations.create_attestation(group_id, user_id)
+
+ if self.is_mine_id(group_id):
+ destination = get_domain_from_id(user_id)
+ else:
+ destination = get_domain_from_id(group_id)
+
+ yield self.transport_client.renew_group_attestation(
+ destination, group_id, user_id,
+ content={"attestation": attestation},
+ )
+
+ yield self.store.update_attestation_renewal(
+ group_id, user_id, attestation
+ )
+
+ for row in rows:
+ group_id = row["group_id"]
+ user_id = row["user_id"]
+
+ preserve_fn(_renew_attestation)(group_id, user_id)
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
new file mode 100644
index 0000000000..fc4edb7f04
--- /dev/null
+++ b/synapse/groups/groups_server.py
@@ -0,0 +1,803 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import UserID, get_domain_from_id, RoomID, GroupID
+
+
+import logging
+import urllib
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Allow users to "knock" or simpkly join depending on rules
+# TODO: Federation admin APIs
+# TODO: is_priveged flag to users and is_public to users and rooms
+# TODO: Audit log for admins (profile updates, membership changes, users who tried
+# to join but were rejected, etc)
+# TODO: Flairs
+
+
+class GroupsServerHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.room_list_handler = hs.get_room_list_handler()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.keyring = hs.get_keyring()
+ self.is_mine_id = hs.is_mine_id
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+ self.attestations = hs.get_groups_attestation_signing()
+ self.transport_client = hs.get_federation_transport_client()
+ self.profile_handler = hs.get_profile_handler()
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ @defer.inlineCallbacks
+ def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
+ """Check that the group is ours, and optionally if it exists.
+
+ If group does exist then return group.
+
+ Args:
+ group_id (str)
+ and_exists (bool): whether to also check if group exists
+ and_is_admin (str): whether to also check if given str is a user_id
+ that is an admin
+ """
+ if not self.is_mine_id(group_id):
+ raise SynapseError(400, "Group not on this server")
+
+ group = yield self.store.get_group(group_id)
+ if and_exists and not group:
+ raise SynapseError(404, "Unknown group")
+
+ if and_is_admin:
+ is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
+ if not is_admin:
+ raise SynapseError(403, "User is not admin in group")
+
+ defer.returnValue(group)
+
+ @defer.inlineCallbacks
+ def get_group_summary(self, group_id, requester_user_id):
+ """Get the summary for a group as seen by requester_user_id.
+
+ The group summary consists of the profile of the room, and a curated
+ list of users and rooms. These list *may* be organised by role/category.
+ The roles/categories are ordered, and so are the users/rooms within them.
+
+ A user/room may appear in multiple roles/categories.
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ profile = yield self.get_group_profile(group_id, requester_user_id)
+
+ users, roles = yield self.store.get_users_for_summary_by_role(
+ group_id, include_private=is_user_in_group,
+ )
+
+ # TODO: Add profiles to users
+
+ rooms, categories = yield self.store.get_rooms_for_summary_by_category(
+ group_id, include_private=is_user_in_group,
+ )
+
+ for room_entry in rooms:
+ room_id = room_entry["room_id"]
+ joined_users = yield self.store.get_users_in_room(room_id)
+ entry = yield self.room_list_handler.generate_room_entry(
+ room_id, len(joined_users),
+ with_alias=False, allow_private=True,
+ )
+ entry = dict(entry) # so we don't change whats cached
+ entry.pop("room_id", None)
+
+ room_entry["profile"] = entry
+
+ rooms.sort(key=lambda e: e.get("order", 0))
+
+ for entry in users:
+ user_id = entry["user_id"]
+
+ if not self.is_mine_id(requester_user_id):
+ attestation = yield self.store.get_remote_attestation(group_id, user_id)
+ if not attestation:
+ continue
+
+ entry["attestation"] = attestation
+ else:
+ entry["attestation"] = self.attestations.create_attestation(
+ group_id, user_id,
+ )
+
+ user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ entry.update(user_profile)
+
+ users.sort(key=lambda e: e.get("order", 0))
+
+ membership_info = yield self.store.get_users_membership_info_in_group(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue({
+ "profile": profile,
+ "users_section": {
+ "users": users,
+ "roles": roles,
+ "total_user_count_estimate": 0, # TODO
+ },
+ "rooms_section": {
+ "rooms": rooms,
+ "categories": categories,
+ "total_room_count_estimate": 0, # TODO
+ },
+ "user": membership_info,
+ })
+
+ @defer.inlineCallbacks
+ def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
+ """Add/update a room to the group summary
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+
+ RoomID.from_string(room_id) # Ensure valid room id
+
+ order = content.get("order", None)
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_room_to_summary(
+ group_id=group_id,
+ room_id=room_id,
+ category_id=category_id,
+ order=order,
+ is_public=is_public,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
+ """Remove a room from the summary
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+
+ yield self.store.remove_room_from_summary(
+ group_id=group_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id, user_id):
+ """Get all categories in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ categories = yield self.store.get_group_categories(
+ group_id=group_id,
+ )
+ defer.returnValue({"categories": categories})
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, user_id, category_id):
+ """Get a specific category in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ res = yield self.store.get_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def update_group_category(self, group_id, user_id, category_id, content):
+ """Add/Update a group category
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+
+ is_public = _parse_visibility_from_contents(content)
+ profile = content.get("profile")
+
+ yield self.store.upsert_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ is_public=is_public,
+ profile=profile,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_category(self, group_id, user_id, category_id):
+ """Delete a group category
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+
+ yield self.store.remove_group_category(
+ group_id=group_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id, user_id):
+ """Get all roles in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ roles = yield self.store.get_group_roles(
+ group_id=group_id,
+ )
+ defer.returnValue({"roles": roles})
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, user_id, role_id):
+ """Get a specific role in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ res = yield self.store.get_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def update_group_role(self, group_id, user_id, role_id, content):
+ """Add/update a role in a group
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+
+ is_public = _parse_visibility_from_contents(content)
+
+ profile = content.get("profile")
+
+ yield self.store.upsert_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ is_public=is_public,
+ profile=profile,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_role(self, group_id, user_id, role_id):
+ """Remove role from group
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
+
+ yield self.store.remove_group_role(
+ group_id=group_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
+ content):
+ """Add/update a users entry in the group summary
+ """
+ yield self.check_group_is_ours(
+ group_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ order = content.get("order", None)
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_user_to_summary(
+ group_id=group_id,
+ user_id=user_id,
+ role_id=role_id,
+ order=order,
+ is_public=is_public,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
+ """Remove a user from the group summary
+ """
+ yield self.check_group_is_ours(
+ group_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ yield self.store.remove_user_from_summary(
+ group_id=group_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def get_group_profile(self, group_id, requester_user_id):
+ """Get the group profile as seen by requester_user_id
+ """
+
+ yield self.check_group_is_ours(group_id)
+
+ group_description = yield self.store.get_group(group_id)
+
+ if group_description:
+ defer.returnValue(group_description)
+ else:
+ raise SynapseError(404, "Unknown group")
+
+ @defer.inlineCallbacks
+ def update_group_profile(self, group_id, requester_user_id, content):
+ """Update the group profile
+ """
+ yield self.check_group_is_ours(
+ group_id, and_exists=True, and_is_admin=requester_user_id,
+ )
+
+ profile = {}
+ for keyname in ("name", "avatar_url", "short_description",
+ "long_description"):
+ if keyname in content:
+ value = content[keyname]
+ if not isinstance(value, basestring):
+ raise SynapseError(400, "%r value is not a string" % (keyname,))
+ profile[keyname] = value
+
+ yield self.store.update_group_profile(group_id, profile)
+
+ @defer.inlineCallbacks
+ def get_users_in_group(self, group_id, requester_user_id):
+ """Get the users in group as seen by requester_user_id.
+
+ The ordering is arbitrary at the moment
+ """
+
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ user_results = yield self.store.get_users_in_group(
+ group_id, include_private=is_user_in_group,
+ )
+
+ chunk = []
+ for user_result in user_results:
+ g_user_id = user_result["user_id"]
+ is_public = user_result["is_public"]
+
+ entry = {"user_id": g_user_id}
+
+ profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
+ entry.update(profile)
+
+ if not is_public:
+ entry["is_public"] = False
+
+ if not self.is_mine_id(g_user_id):
+ attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
+ if not attestation:
+ continue
+
+ entry["attestation"] = attestation
+ else:
+ entry["attestation"] = self.attestations.create_attestation(
+ group_id, g_user_id,
+ )
+
+ chunk.append(entry)
+
+ # TODO: If admin add lists of users whose attestations have timed out
+
+ defer.returnValue({
+ "chunk": chunk,
+ "total_user_count_estimate": len(user_results),
+ })
+
+ @defer.inlineCallbacks
+ def get_invited_users_in_group(self, group_id, requester_user_id):
+ """Get the users that have been invited to a group as seen by requester_user_id.
+
+ The ordering is arbitrary at the moment
+ """
+
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ if not is_user_in_group:
+ raise SynapseError(403, "User not in group")
+
+ invited_users = yield self.store.get_invited_users_in_group(group_id)
+
+ user_profiles = []
+
+ for user_id in invited_users:
+ user_profile = {
+ "user_id": user_id
+ }
+ try:
+ profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ user_profile.update(profile)
+ except Exception as e:
+ logger.warn("Error getting profile for %s: %s", user_id, e)
+ user_profiles.append(user_profile)
+
+ defer.returnValue({
+ "chunk": user_profiles,
+ "total_user_count_estimate": len(invited_users),
+ })
+
+ @defer.inlineCallbacks
+ def get_rooms_in_group(self, group_id, requester_user_id):
+ """Get the rooms in group as seen by requester_user_id
+
+ This returns rooms in order of decreasing number of joined users
+ """
+
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+ room_results = yield self.store.get_rooms_in_group(
+ group_id, include_private=is_user_in_group,
+ )
+
+ chunk = []
+ for room_result in room_results:
+ room_id = room_result["room_id"]
+ is_public = room_result["is_public"]
+
+ joined_users = yield self.store.get_users_in_room(room_id)
+ entry = yield self.room_list_handler.generate_room_entry(
+ room_id, len(joined_users),
+ with_alias=False, allow_private=True,
+ )
+
+ if not entry:
+ continue
+
+ if not is_public:
+ entry["is_public"] = False
+
+ chunk.append(entry)
+
+ chunk.sort(key=lambda e: -e["num_joined_members"])
+
+ defer.returnValue({
+ "chunk": chunk,
+ "total_room_count_estimate": len(room_results),
+ })
+
+ @defer.inlineCallbacks
+ def add_room_to_group(self, group_id, requester_user_id, room_id, content):
+ """Add room to group
+ """
+ RoomID.from_string(room_id) # Ensure valid room id
+
+ yield self.check_group_is_ours(
+ group_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def remove_room_from_group(self, group_id, requester_user_id, room_id):
+ """Remove room from group
+ """
+ yield self.check_group_is_ours(
+ group_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ yield self.store.remove_room_from_group(group_id, room_id)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def invite_to_group(self, group_id, user_id, requester_user_id, content):
+ """Invite user to group
+ """
+
+ group = yield self.check_group_is_ours(
+ group_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ # TODO: Check if user knocked
+ # TODO: Check if user is already invited
+
+ content = {
+ "profile": {
+ "name": group["name"],
+ "avatar_url": group["avatar_url"],
+ },
+ "inviter": requester_user_id,
+ }
+
+ if self.hs.is_mine_id(user_id):
+ groups_local = self.hs.get_groups_local_handler()
+ res = yield groups_local.on_invite(group_id, user_id, content)
+ local_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content.update({
+ "attestation": local_attestation,
+ })
+
+ res = yield self.transport_client.invite_to_group_notification(
+ get_domain_from_id(user_id), group_id, user_id, content
+ )
+
+ user_profile = res.get("user_profile", {})
+ yield self.store.add_remote_profile_cache(
+ user_id,
+ displayname=user_profile.get("displayname"),
+ avatar_url=user_profile.get("avatar_url"),
+ )
+
+ if res["state"] == "join":
+ if not self.hs.is_mine_id(user_id):
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+ else:
+ remote_attestation = None
+
+ yield self.store.add_user_to_group(
+ group_id, user_id,
+ is_admin=False,
+ is_public=False, # TODO
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+ elif res["state"] == "invite":
+ yield self.store.add_group_invite(
+ group_id, user_id,
+ )
+ defer.returnValue({
+ "state": "invite"
+ })
+ elif res["state"] == "reject":
+ defer.returnValue({
+ "state": "reject"
+ })
+ else:
+ raise SynapseError(502, "Unknown state returned by HS")
+
+ @defer.inlineCallbacks
+ def accept_invite(self, group_id, user_id, content):
+ """User tries to accept an invite to the group.
+
+ This is different from them asking to join, and so should error if no
+ invite exists (and they're not a member of the group)
+ """
+
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ if not self.store.is_user_invited_to_local_group(group_id, user_id):
+ raise SynapseError(403, "User not invited to group")
+
+ if not self.hs.is_mine_id(user_id):
+ remote_attestation = content["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+ else:
+ remote_attestation = None
+
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+
+ is_public = _parse_visibility_from_contents(content)
+
+ yield self.store.add_user_to_group(
+ group_id, user_id,
+ is_admin=False,
+ is_public=is_public,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+
+ defer.returnValue({
+ "state": "join",
+ "attestation": local_attestation,
+ })
+
+ @defer.inlineCallbacks
+ def knock(self, group_id, user_id, content):
+ """A user requests becoming a member of the group
+ """
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ raise NotImplementedError()
+
+ @defer.inlineCallbacks
+ def accept_knock(self, group_id, user_id, content):
+ """Accept a users knock to the room.
+
+ Errors if the user hasn't knocked, rather than inviting them.
+ """
+
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ raise NotImplementedError()
+
+ @defer.inlineCallbacks
+ def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ """Remove a user from the group; either a user is leaving or and admin
+ kicked htem.
+ """
+
+ yield self.check_group_is_ours(group_id, and_exists=True)
+
+ is_kick = False
+ if requester_user_id != user_id:
+ is_admin = yield self.store.is_user_admin_in_group(
+ group_id, requester_user_id
+ )
+ if not is_admin:
+ raise SynapseError(403, "User is not admin in group")
+
+ is_kick = True
+
+ yield self.store.remove_user_from_group(
+ group_id, user_id,
+ )
+
+ if is_kick:
+ if self.hs.is_mine_id(user_id):
+ groups_local = self.hs.get_groups_local_handler()
+ yield groups_local.user_removed_from_group(group_id, user_id, {})
+ else:
+ yield self.transport_client.remove_user_from_group_notification(
+ get_domain_from_id(user_id), group_id, user_id, {}
+ )
+
+ if not self.hs.is_mine_id(user_id):
+ yield self.store.maybe_delete_remote_profile_cache(user_id)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def create_group(self, group_id, user_id, content):
+ group = yield self.check_group_is_ours(group_id)
+
+ _validate_group_id(group_id)
+
+ logger.info("Attempting to create group with ID: %r", group_id)
+ if group:
+ raise SynapseError(400, "Group already exists")
+
+ is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
+ if not is_admin:
+ if not self.hs.config.enable_group_creation:
+ raise SynapseError(
+ 403, "Only server admin can create group on this server",
+ )
+ localpart = GroupID.from_string(group_id).localpart
+ if not localpart.startswith(self.hs.config.group_creation_prefix):
+ raise SynapseError(
+ 400,
+ "Can only create groups with prefix %r on this server" % (
+ self.hs.config.group_creation_prefix,
+ ),
+ )
+
+ profile = content.get("profile", {})
+ name = profile.get("name")
+ avatar_url = profile.get("avatar_url")
+ short_description = profile.get("short_description")
+ long_description = profile.get("long_description")
+ user_profile = content.get("user_profile", {})
+
+ yield self.store.create_group(
+ group_id,
+ user_id,
+ name=name,
+ avatar_url=avatar_url,
+ short_description=short_description,
+ long_description=long_description,
+ )
+
+ if not self.hs.is_mine_id(user_id):
+ remote_attestation = content["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ user_id=user_id,
+ group_id=group_id,
+ )
+
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ else:
+ local_attestation = None
+ remote_attestation = None
+
+ yield self.store.add_user_to_group(
+ group_id, user_id,
+ is_admin=True,
+ is_public=True, # TODO
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ )
+
+ if not self.hs.is_mine_id(user_id):
+ yield self.store.add_remote_profile_cache(
+ user_id,
+ displayname=user_profile.get("displayname"),
+ avatar_url=user_profile.get("avatar_url"),
+ )
+
+ defer.returnValue({
+ "group_id": group_id,
+ })
+
+
+def _parse_visibility_from_contents(content):
+ """Given a content for a request parse out whether the entity should be
+ public or not
+ """
+
+ visibility = content.get("visibility")
+ if visibility:
+ vis_type = visibility["type"]
+ if vis_type not in ("public", "private"):
+ raise SynapseError(
+ 400, "Synapse only supports 'public'/'private' visibility"
+ )
+ is_public = vis_type == "public"
+ else:
+ is_public = True
+
+ return is_public
+
+
+def _validate_group_id(group_id):
+ """Validates the group ID is valid for creation on this home server
+ """
+ localpart = GroupID.from_string(group_id).localpart
+
+ if localpart.lower() != localpart:
+ raise SynapseError(400, "Group ID must be lower case")
+
+ if urllib.quote(localpart.encode('utf-8')) != localpart:
+ raise SynapseError(
+ 400,
+ "Group ID can only contain characters a-z, 0-9, or '_-./'",
+ )
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 5ad408f549..53213cdccf 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -20,7 +20,6 @@ from .room import (
from .room_member import RoomMemberHandler
from .message import MessageHandler
from .federation import FederationHandler
-from .profile import ProfileHandler
from .directory import DirectoryHandler
from .admin import AdminHandler
from .identity import IdentityHandler
@@ -52,7 +51,6 @@ class Handlers(object):
self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs)
self.federation_handler = FederationHandler(hs)
- self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
self.identity_handler = IdentityHandler(hs)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 943554ce98..a0464ae5c0 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -40,6 +40,8 @@ class DirectoryHandler(BaseHandler):
"directory", self.on_directory_query
)
+ self.spam_checker = hs.get_spam_checker()
+
@defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None, creator=None):
# general association creation for both human users and app services
@@ -73,6 +75,11 @@ class DirectoryHandler(BaseHandler):
# association creation for human users
# TODO(erikj): Do user auth.
+ if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ raise SynapseError(
+ 403, "This user is not permitted to create this alias",
+ )
+
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
@@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler):
room_id (str)
visibility (str): "public" or "private"
"""
+ if not self.spam_checker.user_may_publish_room(
+ requester.user.to_string(), room_id
+ ):
+ raise AuthError(
+ 403,
+ "This user is not permitted to publish rooms to the room list"
+ )
+
if requester.is_guest:
raise AuthError(403, "Guests cannot edit the published room list")
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 18f87cad67..7711cded01 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""Contains handlers for federation events."""
-import synapse.util.logcontext
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -26,10 +25,7 @@ from synapse.api.errors import (
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
- preserve_fn, preserve_context_over_deferred
-)
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, Linearizer
@@ -77,6 +73,7 @@ class FederationHandler(BaseHandler):
self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
self.pusher_pool = hs.get_pusherpool()
+ self.spam_checker = hs.get_spam_checker()
self.replication_layer.set_handler(self)
@@ -125,6 +122,28 @@ class FederationHandler(BaseHandler):
self.room_queues[pdu.room_id].append((pdu, origin))
return
+ # If we're no longer in the room just ditch the event entirely. This
+ # is probably an old server that has come back and thinks we're still
+ # in the room (or we've been rejoined to the room by a state reset).
+ #
+ # If we were never in the room then maybe our database got vaped and
+ # we should check if we *are* in fact in the room. If we are then we
+ # can magically rejoin the room.
+ is_in_room = yield self.auth.check_host_in_room(
+ pdu.room_id,
+ self.server_name
+ )
+ if not is_in_room:
+ was_in_room = yield self.store.was_host_joined(
+ pdu.room_id, self.server_name,
+ )
+ if was_in_room:
+ logger.info(
+ "Ignoring PDU %s for room %s from %s as we've left the room!",
+ pdu.event_id, pdu.room_id, origin,
+ )
+ return
+
state = None
auth_chain = []
@@ -591,9 +610,9 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch
)
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.replication_layer.get_pdu)(
+ logcontext.preserve_fn(self.replication_layer.get_pdu)(
[dest],
event_id,
outlier=True,
@@ -785,10 +804,14 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
- states = yield preserve_context_over_deferred(defer.gatherResults([
- preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
- for e in event_ids
- ]))
+ states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [
+ logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
+ room_id, [e]
+ )
+ for e in event_ids
+ ], consumeErrors=True,
+ ))
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
@@ -941,9 +964,7 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
- synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
- room_queue
- )
+ logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
defer.returnValue(True)
@@ -1070,6 +1091,9 @@ class FederationHandler(BaseHandler):
"""
event = pdu
+ if event.state_key is None:
+ raise SynapseError(400, "The invite event did not have a state key")
+
is_blocked = yield self.store.is_room_blocked(event.room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@@ -1077,6 +1101,13 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ if not self.spam_checker.user_may_invite(
+ event.sender, event.state_key, event.room_id,
+ ):
+ raise SynapseError(
+ 403, "This user is not permitted to send invites to this server/user"
+ )
+
membership = event.content.get("membership")
if event.type != EventTypes.Member or membership != Membership.INVITE:
raise SynapseError(400, "The event was not an m.room.member invite event")
@@ -1085,9 +1116,6 @@ class FederationHandler(BaseHandler):
if sender_domain != origin:
raise SynapseError(400, "The invite event was not from the server sending it")
- if event.state_key is None:
- raise SynapseError(400, "The invite event did not have a state key")
-
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
@@ -1430,7 +1458,7 @@ class FederationHandler(BaseHandler):
if not backfilled:
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.pusher_pool.on_new_notifications)(
+ logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
event_stream_id, max_stream_id
)
@@ -1443,16 +1471,16 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
- contexts = yield preserve_context_over_deferred(defer.gatherResults(
+ contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self._prep_event)(
+ logcontext.preserve_fn(self._prep_event)(
origin,
ev_info["event"],
state=ev_info.get("state"),
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
- ]
+ ], consumeErrors=True,
))
yield self.store.persist_events(
@@ -1760,18 +1788,17 @@ class FederationHandler(BaseHandler):
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
- different_events = yield preserve_context_over_deferred(defer.gatherResults(
- [
- preserve_fn(self.store.get_event)(
+ different_events = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults([
+ logcontext.preserve_fn(self.store.get_event)(
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
- ],
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ], consumeErrors=True)
+ ).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
new file mode 100644
index 0000000000..6699d0888f
--- /dev/null
+++ b/synapse/handlers/groups_local.py
@@ -0,0 +1,417 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def _create_rerouter(func_name):
+ """Returns a function that looks at the group id and calls the function
+ on federation or the local group server if the group is local
+ """
+ def f(self, group_id, *args, **kwargs):
+ if self.is_mine_id(group_id):
+ return getattr(self.groups_server_handler, func_name)(
+ group_id, *args, **kwargs
+ )
+ else:
+ destination = get_domain_from_id(group_id)
+ return getattr(self.transport_client, func_name)(
+ destination, group_id, *args, **kwargs
+ )
+ return f
+
+
+class GroupsLocalHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.room_list_handler = hs.get_room_list_handler()
+ self.groups_server_handler = hs.get_groups_server_handler()
+ self.transport_client = hs.get_federation_transport_client()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.keyring = hs.get_keyring()
+ self.is_mine_id = hs.is_mine_id
+ self.signing_key = hs.config.signing_key[0]
+ self.server_name = hs.hostname
+ self.notifier = hs.get_notifier()
+ self.attestations = hs.get_groups_attestation_signing()
+
+ self.profile_handler = hs.get_profile_handler()
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ # The following functions merely route the query to the local groups server
+ # or federation depending on if the group is local or remote
+
+ get_group_profile = _create_rerouter("get_group_profile")
+ update_group_profile = _create_rerouter("update_group_profile")
+ get_rooms_in_group = _create_rerouter("get_rooms_in_group")
+
+ get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
+
+ add_room_to_group = _create_rerouter("add_room_to_group")
+ remove_room_from_group = _create_rerouter("remove_room_from_group")
+
+ update_group_summary_room = _create_rerouter("update_group_summary_room")
+ delete_group_summary_room = _create_rerouter("delete_group_summary_room")
+
+ update_group_category = _create_rerouter("update_group_category")
+ delete_group_category = _create_rerouter("delete_group_category")
+ get_group_category = _create_rerouter("get_group_category")
+ get_group_categories = _create_rerouter("get_group_categories")
+
+ update_group_summary_user = _create_rerouter("update_group_summary_user")
+ delete_group_summary_user = _create_rerouter("delete_group_summary_user")
+
+ update_group_role = _create_rerouter("update_group_role")
+ delete_group_role = _create_rerouter("delete_group_role")
+ get_group_role = _create_rerouter("get_group_role")
+ get_group_roles = _create_rerouter("get_group_roles")
+
+ @defer.inlineCallbacks
+ def get_group_summary(self, group_id, requester_user_id):
+ """Get the group summary for a group.
+
+ If the group is remote we check that the users have valid attestations.
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_group_summary(
+ group_id, requester_user_id
+ )
+ else:
+ res = yield self.transport_client.get_group_summary(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ )
+
+ group_server_name = get_domain_from_id(group_id)
+
+ # Loop through the users and validate the attestations.
+ chunk = res["users_section"]["users"]
+ valid_users = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_users.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["users_section"]["users"] = valid_users
+
+ res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
+ res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
+
+ # Add `is_publicised` flag to indicate whether the user has publicised their
+ # membership of the group on their profile
+ result = yield self.store.get_publicised_groups_for_user(requester_user_id)
+ is_publicised = group_id in result
+
+ res.setdefault("user", {})["is_publicised"] = is_publicised
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def create_group(self, group_id, user_id, content):
+ """Create a group
+ """
+
+ logger.info("Asking to create group with ID: %r", group_id)
+
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.create_group(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ content["user_profile"] = yield self.profile_handler.get_profile(user_id)
+
+ res = yield self.transport_client.create_group(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ is_publicised = content.get("publicise", False)
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=True,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_in_group(self, group_id, requester_user_id):
+ """Get users in a group
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+ defer.returnValue(res)
+
+ group_server_name = get_domain_from_id(group_id)
+
+ res = yield self.transport_client.get_users_in_group(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ )
+
+ chunk = res["chunk"]
+ valid_entries = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_entries.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["chunk"] = valid_entries
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def join_group(self, group_id, user_id, content):
+ """Request to join a group
+ """
+ raise NotImplementedError() # TODO
+
+ @defer.inlineCallbacks
+ def accept_invite(self, group_id, user_id, content):
+ """Accept an invite to a group
+ """
+ if self.is_mine_id(group_id):
+ yield self.groups_server_handler.accept_invite(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ res = yield self.transport_client.accept_group_invite(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ # TODO: Check that the group is public and we're being added publically
+ is_publicised = content.get("publicise", False)
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=False,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
+ def invite(self, group_id, user_id, requester_user_id, config):
+ """Invite a user to a group
+ """
+ content = {
+ "requester_user_id": requester_user_id,
+ "config": config,
+ }
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.invite_to_group(
+ group_id, user_id, requester_user_id, content,
+ )
+ else:
+ res = yield self.transport_client.invite_to_group(
+ get_domain_from_id(group_id), group_id, user_id, requester_user_id,
+ content,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def on_invite(self, group_id, user_id, content):
+ """One of our users were invited to a group
+ """
+ # TODO: Support auto join and rejection
+
+ if not self.is_mine_id(user_id):
+ raise SynapseError(400, "User not on this server")
+
+ local_profile = {}
+ if "profile" in content:
+ if "name" in content["profile"]:
+ local_profile["name"] = content["profile"]["name"]
+ if "avatar_url" in content["profile"]:
+ local_profile["avatar_url"] = content["profile"]["avatar_url"]
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="invite",
+ content={"profile": local_profile, "inviter": content["inviter"]},
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+ try:
+ user_profile = yield self.profile_handler.get_profile(user_id)
+ except Exception as e:
+ logger.warn("No profile for user %s: %s", user_id, e)
+ user_profile = {}
+
+ defer.returnValue({"state": "invite", "user_profile": user_profile})
+
+ @defer.inlineCallbacks
+ def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ """Remove a user from a group
+ """
+ if user_id == requester_user_id:
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="leave",
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ # TODO: Should probably remember that we tried to leave so that we can
+ # retry if the group server is currently down.
+
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+ else:
+ content["requester_user_id"] = requester_user_id
+ res = yield self.transport_client.remove_user_from_group(
+ get_domain_from_id(group_id), group_id, requester_user_id,
+ user_id, content,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def user_removed_from_group(self, group_id, user_id, content):
+ """One of our users was removed/kicked from a group
+ """
+ # TODO: Check if user in group
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="leave",
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ @defer.inlineCallbacks
+ def get_joined_groups(self, user_id):
+ group_ids = yield self.store.get_joined_groups(user_id)
+ defer.returnValue({"groups": group_ids})
+
+ @defer.inlineCallbacks
+ def get_publicised_groups_for_user(self, user_id):
+ if self.hs.is_mine_id(user_id):
+ result = yield self.store.get_publicised_groups_for_user(user_id)
+ defer.returnValue({"groups": result})
+ else:
+ result = yield self.transport_client.get_publicised_groups_for_user(
+ get_domain_from_id(user_id), user_id
+ )
+ # TODO: Verify attestations
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def bulk_get_publicised_groups(self, user_ids, proxy=True):
+ destinations = {}
+ local_users = set()
+
+ for user_id in user_ids:
+ if self.hs.is_mine_id(user_id):
+ local_users.add(user_id)
+ else:
+ destinations.setdefault(
+ get_domain_from_id(user_id), set()
+ ).add(user_id)
+
+ if not proxy and destinations:
+ raise SynapseError(400, "Some user_ids are not local")
+
+ results = {}
+ failed_results = []
+ for destination, dest_user_ids in destinations.iteritems():
+ try:
+ r = yield self.transport_client.bulk_get_publicised_groups(
+ destination, list(dest_user_ids),
+ )
+ results.update(r["users"])
+ except Exception:
+ failed_results.extend(dest_user_ids)
+
+ for uid in local_users:
+ results[uid] = yield self.store.get_publicised_groups_for_user(
+ uid
+ )
+
+ defer.returnValue({"users": results})
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da18bf23db..28792788d9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.events import spamcheck
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -26,6 +26,7 @@ from synapse.types import (
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import measure_func
+from synapse.util.frozenutils import unfreeze
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -47,6 +48,7 @@ class MessageHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
+ self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock()
@@ -58,6 +60,8 @@ class MessageHandler(BaseHandler):
self.action_generator = hs.get_action_generator()
+ self.spam_checker = hs.get_spam_checker()
+
@defer.inlineCallbacks
def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id)
@@ -210,7 +214,7 @@ class MessageHandler(BaseHandler):
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
- profile = self.hs.get_handlers().profile_handler
+ profile = self.profile_handler
content = builder.content
try:
@@ -322,9 +326,12 @@ class MessageHandler(BaseHandler):
txn_id=txn_id
)
- if spamcheck.check_event_for_spam(event):
+ spam_error = self.spam_checker.check_event_for_spam(event)
+ if spam_error:
+ if not isinstance(spam_error, basestring):
+ spam_error = "Spam is not permitted here"
raise SynapseError(
- 403, "Spam is not permitted here", Codes.FORBIDDEN
+ 403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
@@ -418,6 +425,51 @@ class MessageHandler(BaseHandler):
[serialize_event(c, now) for c in room_state.values()]
)
+ @defer.inlineCallbacks
+ def get_joined_members(self, requester, room_id):
+ """Get all the joined members in the room and their profile information.
+
+ If the user has left the room return the state events from when they left.
+
+ Args:
+ requester(Requester): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
+ Returns:
+ A dict of user_id to profile info
+ """
+ user_id = requester.user.to_string()
+ if not requester.app_service:
+ # We check AS auth after fetching the room membership, as it
+ # requires us to pull out all joined members anyway.
+ membership, _ = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
+ if membership != Membership.JOIN:
+ raise NotImplementedError(
+ "Getting joined members after leaving is not implemented"
+ )
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+ # If this is an AS, double check that they are allowed to see the members.
+ # This can either be because the AS user is in the room or becuase there
+ # is a user in the room that the AS is "interested in"
+ if requester.app_service and user_id not in users_with_profile:
+ for uid in users_with_profile:
+ if requester.app_service.is_interested_in_user(uid):
+ break
+ else:
+ # Loop fell through, AS has no interested users in room
+ raise AuthError(403, "Appservice not in room")
+
+ defer.returnValue({
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
+ }
+ for user_id, profile in users_with_profile.iteritems()
+ })
+
@measure_func("_create_new_client_event")
@defer.inlineCallbacks
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
@@ -509,7 +561,7 @@ class MessageHandler(BaseHandler):
# Ensure that we can round trip before trying to persist in db
try:
- dump = ujson.dumps(event.content)
+ dump = ujson.dumps(unfreeze(event.content))
ujson.loads(dump)
except:
logger.exception("Failed to encode content: %r", event.content)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 7abee98dea..e56e0a52bf 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -19,14 +19,15 @@ from twisted.internet import defer
import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID
+from synapse.types import UserID, get_domain_from_id
from ._base import BaseHandler
-
logger = logging.getLogger(__name__)
class ProfileHandler(BaseHandler):
+ PROFILE_UPDATE_MS = 60 * 1000
+ PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs):
super(ProfileHandler, self).__init__(hs)
@@ -36,6 +37,63 @@ class ProfileHandler(BaseHandler):
"profile", self.on_profile_query
)
+ self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
+
+ @defer.inlineCallbacks
+ def get_profile(self, user_id):
+ target_user = UserID.from_string(user_id)
+ if self.hs.is_mine(target_user):
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue({
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ })
+ else:
+ try:
+ result = yield self.federation.make_query(
+ destination=target_user.domain,
+ query_type="profile",
+ args={
+ "user_id": user_id,
+ },
+ ignore_backoff=True,
+ )
+ defer.returnValue(result)
+ except CodeMessageException as e:
+ if e.code != 404:
+ logger.exception("Failed to get displayname")
+
+ raise
+
+ @defer.inlineCallbacks
+ def get_profile_from_cache(self, user_id):
+ """Get the profile information from our local cache. If the user is
+ ours then the profile information will always be corect. Otherwise,
+ it may be out of date/missing.
+ """
+ target_user = UserID.from_string(user_id)
+ if self.hs.is_mine(target_user):
+ displayname = yield self.store.get_profile_displayname(
+ target_user.localpart
+ )
+ avatar_url = yield self.store.get_profile_avatar_url(
+ target_user.localpart
+ )
+
+ defer.returnValue({
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ })
+ else:
+ profile = yield self.store.get_from_remote_profile_cache(user_id)
+ defer.returnValue(profile or {})
+
@defer.inlineCallbacks
def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
@@ -182,3 +240,44 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s",
room_id, str(e.message)
)
+
+ def _update_remote_profile_cache(self):
+ """Called periodically to check profiles of remote users we haven't
+ checked in a while.
+ """
+ entries = yield self.store.get_remote_profile_cache_entries_that_expire(
+ last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
+ )
+
+ for user_id, displayname, avatar_url in entries:
+ is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
+ user_id,
+ )
+ if not is_subscribed:
+ yield self.store.maybe_delete_remote_profile_cache(user_id)
+ continue
+
+ try:
+ profile = yield self.federation.make_query(
+ destination=get_domain_from_id(user_id),
+ query_type="profile",
+ args={
+ "user_id": user_id,
+ },
+ ignore_backoff=True,
+ )
+ except:
+ logger.exception("Failed to get avatar_url")
+
+ yield self.store.update_remote_profile_cache(
+ user_id, displayname, avatar_url
+ )
+ continue
+
+ new_name = profile.get("displayname")
+ new_avatar = profile.get("avatar_url")
+
+ # We always hit update to update the last_check timestamp
+ yield self.store.update_remote_profile_cache(
+ user_id, new_name, new_avatar
+ )
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e1cd3a48e9..0525765272 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.util import logcontext
from ._base import BaseHandler
@@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler):
is_new = yield self._handle_new_receipts([receipt])
if is_new:
+ # fire off a process in the background to send the receipt to
+ # remote servers
self._push_remotes([receipt])
@defer.inlineCallbacks
@@ -126,6 +129,7 @@ class ReceiptsHandler(BaseHandler):
defer.returnValue(True)
+ @logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ee3a2269a8..560fb36254 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
+ self.profile_handler = hs.get_profile_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
@@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler):
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
- profile_handler = self.hs.get_handlers().profile_handler
- yield profile_handler.set_displayname(
+ yield self.profile_handler.set_displayname(
user, requester, displayname, by_admin=True,
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 5698d28088..535ba9517c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,6 +60,11 @@ class RoomCreationHandler(BaseHandler):
},
}
+ def __init__(self, hs):
+ super(RoomCreationHandler, self).__init__(hs)
+
+ self.spam_checker = hs.get_spam_checker()
+
@defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True):
""" Creates a new room.
@@ -75,6 +80,9 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
+ if not self.spam_checker.user_may_create_room(user_id):
+ raise SynapseError(403, "You are not permitted to create rooms")
+
if ratelimit:
yield self.ratelimit(requester)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 516cd9a6ac..41e1781df7 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -276,13 +276,14 @@ class RoomListHandler(BaseHandler):
# We've already got enough, so lets just drop it.
return
- result = yield self._generate_room_entry(room_id, num_joined_users)
+ result = yield self.generate_room_entry(room_id, num_joined_users)
if result and _matches_room_entry(result, search_filter):
chunk.append(result)
@cachedInlineCallbacks(num_args=1, cache_context=True)
- def _generate_room_entry(self, room_id, num_joined_users, cache_context):
+ def generate_room_entry(self, room_id, num_joined_users, cache_context,
+ with_alias=True, allow_private=False):
"""Returns the entry for a room
"""
result = {
@@ -316,14 +317,15 @@ class RoomListHandler(BaseHandler):
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
- if join_rule and join_rule != JoinRules.PUBLIC:
+ if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
- aliases = yield self.store.get_aliases_for_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- if aliases:
- result["aliases"] = aliases
+ if with_alias:
+ aliases = yield self.store.get_aliases_for_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ if aliases:
+ result["aliases"] = aliases
name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 9a498c2d3e..970fec0666 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -45,9 +45,12 @@ class RoomMemberHandler(BaseHandler):
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
+ self.profile_handler = hs.get_profile_handler()
+
self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
+ self.spam_checker = hs.get_spam_checker()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
@@ -210,12 +213,26 @@ class RoomMemberHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- if (effective_membership_state == "invite" and
- self.hs.config.block_non_admin_invites):
+ if effective_membership_state == "invite":
+ block_invite = False
is_requester_admin = yield self.auth.is_server_admin(
requester.user,
)
if not is_requester_admin:
+ if self.hs.config.block_non_admin_invites:
+ logger.info(
+ "Blocking invite: user is not admin and non-admin "
+ "invites disabled"
+ )
+ block_invite = True
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(), target.to_string(), room_id,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ block_invite = True
+
+ if block_invite:
raise SynapseError(
403, "Invites have been disabled on this server",
)
@@ -267,7 +284,7 @@ class RoomMemberHandler(BaseHandler):
content["membership"] = Membership.JOIN
- profile = self.hs.get_handlers().profile_handler
+ profile = self.profile_handler
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index dd0ec00ae6..219529936f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
return True
+class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
+ "join",
+ "invite",
+ "leave",
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ return bool(self.join or self.invite or self.leave)
+
+
class DeviceLists(collections.namedtuple("DeviceLists", [
"changed", # list of user_ids whose devices may have changed
"left", # list of user_ids whose devices we no longer track
@@ -129,6 +140,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"device_lists", # List of user_ids whose devices have chanegd
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device
+ "groups",
])):
__slots__ = []
@@ -144,7 +156,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.archived or
self.account_data or
self.to_device or
- self.device_lists
+ self.device_lists or
+ self.groups
)
@@ -595,6 +608,8 @@ class SyncHandler(object):
user_id, device_id
)
+ yield self._generate_sync_entry_for_groups(sync_result_builder)
+
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@@ -603,10 +618,57 @@ class SyncHandler(object):
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
+ groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token,
))
+ @measure_func("_generate_sync_entry_for_groups")
+ @defer.inlineCallbacks
+ def _generate_sync_entry_for_groups(self, sync_result_builder):
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+ now_token = sync_result_builder.now_token
+
+ if since_token and since_token.groups_key:
+ results = yield self.store.get_groups_changes_for_user(
+ user_id, since_token.groups_key, now_token.groups_key,
+ )
+ else:
+ results = yield self.store.get_all_groups_for_user(
+ user_id, now_token.groups_key,
+ )
+
+ invited = {}
+ joined = {}
+ left = {}
+ for result in results:
+ membership = result["membership"]
+ group_id = result["group_id"]
+ gtype = result["type"]
+ content = result["content"]
+
+ if membership == "join":
+ if gtype == "membership":
+ # TODO: Add profile
+ content.pop("membership", None)
+ joined[group_id] = content["content"]
+ else:
+ joined.setdefault(group_id, {})[gtype] = content
+ elif membership == "invite":
+ if gtype == "membership":
+ content.pop("membership", None)
+ invited[group_id] = content["content"]
+ else:
+ if gtype == "membership":
+ left[group_id] = content["content"]
+
+ sync_result_builder.groups = GroupsSyncResult(
+ join=joined,
+ invite=invited,
+ leave=left,
+ )
+
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder,
@@ -1368,6 +1430,7 @@ class SyncResultBuilder(object):
self.invited = []
self.archived = []
self.device = []
+ self.groups = None
self.to_device = []
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 241b17f2cb..a97532162f 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -354,16 +354,28 @@ def _get_hosts_for_srv_record(dns_client, host):
return res[0]
- def eb(res):
- res.trap(DNSNameError)
- return []
+ def eb(res, record_type):
+ if res.check(DNSNameError):
+ return []
+ logger.warn("Error looking up %s for %s: %s",
+ record_type, host, res, res.value)
+ return res
# no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
- results = yield defer.gatherResults([d1, d2], consumeErrors=True)
+ results = yield defer.DeferredList(
+ [d1, d2], consumeErrors=True)
+
+ # if all of the lookups failed, raise an exception rather than blowing out
+ # the cache with an empty result.
+ if results and all(s == defer.FAILURE for (s, _) in results):
+ defer.returnValue(results[0][1])
+
+ for (success, result) in results:
+ if success == defer.FAILURE:
+ continue
- for result in results:
for answer in result:
if not answer.payload:
continue
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 747a791f83..8c8b7fa656 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object):
raise
logger.warn(
- "{%s} Sending request failed to %s: %s %s: %s - %s",
+ "{%s} Sending request failed to %s: %s %s: %s",
txn_id,
destination,
method,
url_bytes,
- type(e).__name__,
_flatten_response_never_received(e),
)
- log_result = "%s - %s" % (
- type(e).__name__, _flatten_response_never_received(e),
- )
+ log_result = _flatten_response_never_received(e)
if retries_left and not timeout:
if long_retries:
@@ -347,7 +344,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
- timeout=None, ignore_backoff=False):
+ timeout=None, ignore_backoff=False, args={}):
""" Sends the specifed json data using POST
Args:
@@ -383,6 +380,7 @@ class MatrixFederationHttpClient(object):
destination,
"POST",
path,
+ query_bytes=encode_query_args(args),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
@@ -427,13 +425,6 @@ class MatrixFederationHttpClient(object):
"""
logger.debug("get_json args: %s", args)
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, basestring):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
@@ -444,7 +435,7 @@ class MatrixFederationHttpClient(object):
destination,
"GET",
path,
- query_bytes=query_bytes,
+ query_bytes=encode_query_args(args),
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
@@ -461,6 +452,52 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
+ def delete_json(self, destination, path, long_retries=False,
+ timeout=None, ignore_backoff=False, args={}):
+ """Send a DELETE request to the remote expecting some json response
+
+ Args:
+ destination (str): The remote server to send the HTTP request
+ to.
+ path (str): The HTTP path.
+ long_retries (bool): A boolean that indicates whether we should
+ retry for a short or long time.
+ timeout(int): How long to try (in ms) the destination for before
+ giving up. None indicates no timeout.
+ ignore_backoff (bool): true to ignore the historical backoff data and
+ try the request anyway.
+ Returns:
+ Deferred: Succeeds when we get a 2xx HTTP response. The result
+ will be the decoded JSON body.
+
+ Fails with ``HTTPRequestException`` if we get an HTTP response
+ code >= 300.
+
+ Fails with ``NotRetryingDestination`` if we are not yet ready
+ to retry this server.
+ """
+
+ response = yield self._request(
+ destination,
+ "DELETE",
+ path,
+ query_bytes=encode_query_args(args),
+ headers_dict={"Content-Type": ["application/json"]},
+ long_retries=long_retries,
+ timeout=timeout,
+ ignore_backoff=ignore_backoff,
+ )
+
+ if 200 <= response.code < 300:
+ # We need to update the transactions table to say it was sent?
+ check_content_type_is_json(response.headers)
+
+ with logcontext.PreserveLoggingContext():
+ body = yield readBody(response)
+
+ defer.returnValue(json.loads(body))
+
+ @defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None,
ignore_backoff=False):
@@ -578,12 +615,14 @@ class _JsonProducer(object):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
- return ", ".join(
+ reasons = ", ".join(
_flatten_response_never_received(f.value)
for f in e.reasons
)
+
+ return "%s:[%s]" % (type(e).__name__, reasons)
else:
- return "%s: %s" % (type(e).__name__, e.message,)
+ return repr(e)
def check_content_type_is_json(headers):
@@ -610,3 +649,15 @@ def check_content_type_is_json(headers):
raise RuntimeError(
"Content-Type not application/json: was '%s'" % c_type
)
+
+
+def encode_query_args(args):
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, basestring):
+ vs = [vs]
+ encoded_args[k] = [v.encode("UTF-8") for v in vs]
+
+ query_bytes = urllib.urlencode(encoded_args, True)
+
+ return query_bytes
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 7ef3d526b1..8a27e3b422 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -145,7 +145,9 @@ def wrap_request_handler(request_handler, include_metrics=False):
"error": "Internal server error",
"errcode": Codes.UNKNOWN,
},
- send_cors=True
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ version_string=self.version_string,
)
finally:
try:
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 85effdfa46..7a18afe5f9 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [
}
]
},
+ {
+ 'rule_id': 'global/override/.m.rule.roomnotif',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.body',
+ 'pattern': '@room',
+ '_id': '_roomnotif_content',
+ },
+ {
+ 'kind': 'sender_notification_permission',
+ 'key': 'room',
+ '_id': '_roomnotif_pl',
+ },
+ ],
+ 'actions': [
+ 'notify', {
+ 'set_tweak': 'highlight',
+ 'value': True,
+ }
+ ]
+ }
]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index b0d64aa6c4..425a017bdf 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,11 +20,13 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.event_auth import get_user_power_level
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import get_metrics_for
from synapse.util.caches import metrics as cache_metrics
from synapse.util.caches.descriptors import cached
from synapse.util.async import Linearizer
+from synapse.state import POWER_KEY
from collections import namedtuple
@@ -59,6 +62,7 @@ class BulkPushRuleEvaluator(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
self.room_push_rule_cache_metrics = cache_metrics.register_cache(
"cache",
@@ -109,6 +113,29 @@ class BulkPushRuleEvaluator(object):
)
@defer.inlineCallbacks
+ def _get_power_levels_and_sender_level(self, event, context):
+ pl_event_id = context.prev_state_ids.get(POWER_KEY)
+ if pl_event_id:
+ # fastpath: if there's a power level event, that's all we need, and
+ # not having a power level event is an extreme edge case
+ pl_event = yield self.store.get_event(pl_event_id)
+ auth_events = {POWER_KEY: pl_event}
+ else:
+ auth_events_ids = yield self.auth.compute_auth_events(
+ event, context.prev_state_ids, for_verification=False,
+ )
+ auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_events.itervalues()
+ }
+
+ sender_level = get_user_power_level(event.sender, auth_events)
+
+ pl_event = auth_events.get(POWER_KEY)
+
+ defer.returnValue((pl_event.content if pl_event else {}, sender_level))
+
+ @defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
"""Given an event and context, evaluate the push rules and return
the results
@@ -123,7 +150,13 @@ class BulkPushRuleEvaluator(object):
event, context
)
- evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
+ (power_levels, sender_power_level) = (
+ yield self._get_power_levels_and_sender_level(event, context)
+ )
+
+ evaluator = PushRuleEvaluatorForEvent(
+ event, len(room_members), sender_power_level, power_levels,
+ )
condition_cache = {}
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 172c27c137..3601f2d365 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(ev, condition, room_member_count):
+ return _test_ineq_condition(condition, room_member_count)
+
+
+def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
+ notif_level_key = condition.get('key')
+ if notif_level_key is None:
+ return False
+
+ notif_levels = power_levels.get('notifications', {})
+ room_notif_level = notif_levels.get(notif_level_key, 50)
+
+ return sender_power_level >= room_notif_level
+
+
+def _test_ineq_condition(condition, number):
if 'is' not in condition:
return False
m = INEQUALITY_EXPR.match(condition['is'])
@@ -41,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count):
rhs = int(rhs)
if ineq == '' or ineq == '==':
- return room_member_count == rhs
+ return number == rhs
elif ineq == '<':
- return room_member_count < rhs
+ return number < rhs
elif ineq == '>':
- return room_member_count > rhs
+ return number > rhs
elif ineq == '>=':
- return room_member_count >= rhs
+ return number >= rhs
elif ineq == '<=':
- return room_member_count <= rhs
+ return number <= rhs
else:
return False
@@ -65,9 +81,11 @@ def tweaks_for_actions(actions):
class PushRuleEvaluatorForEvent(object):
- def __init__(self, event, room_member_count):
+ def __init__(self, event, room_member_count, sender_power_level, power_levels):
self._event = event
self._room_member_count = room_member_count
+ self._sender_power_level = sender_power_level
+ self._power_levels = power_levels
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
@@ -81,6 +99,10 @@ class PushRuleEvaluatorForEvent(object):
return _room_member_count(
self._event, condition, self._room_member_count
)
+ elif condition['kind'] == 'sender_notification_permission':
+ return _sender_notification_permission(
+ self._event, condition, self._sender_power_level, self._power_levels,
+ )
else:
return True
@@ -183,7 +205,7 @@ def _glob_to_re(glob, word_boundary):
r,
)
if word_boundary:
- r = r"\b%s\b" % (r,)
+ r = _re_word_boundary(r)
return re.compile(r, flags=re.IGNORECASE)
else:
@@ -192,7 +214,7 @@ def _glob_to_re(glob, word_boundary):
return re.compile(r, flags=re.IGNORECASE)
elif word_boundary:
r = re.escape(glob)
- r = r"\b%s\b" % (r,)
+ r = _re_word_boundary(r)
return re.compile(r, flags=re.IGNORECASE)
else:
@@ -200,6 +222,18 @@ def _glob_to_re(glob, word_boundary):
return re.compile(r, flags=re.IGNORECASE)
+def _re_word_boundary(r):
+ """
+ Adds word boundary characters to the start and end of an
+ expression to require that the match occur as a whole word,
+ but do so respecting the fact that strings starting or ending
+ with non-word characters will change word boundaries.
+ """
+ # we can't use \b as it chokes on unicode. however \W seems to be okay
+ # as shorthand for [^0-9A-Za-z_].
+ return r"(^|\W)%s(\W|$)" % (r,)
+
+
def _flatten_dict(d, prefix=[], result=None):
if result is None:
result = {}
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
new file mode 100644
index 0000000000..0bc4bce5b0
--- /dev/null
+++ b/synapse/replication/slave/storage/groups.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedGroupServerStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+
+ self.hs = hs
+
+ self._group_updates_id_gen = SlavedIdTracker(
+ db_conn, "local_group_updates", "stream_id",
+ )
+ self._group_updates_stream_cache = StreamChangeCache(
+ "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
+ )
+
+ get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
+ get_group_stream_token = DataStore.get_group_stream_token.__func__
+ get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
+
+ def stream_positions(self):
+ result = super(SlavedGroupServerStore, self).stream_positions()
+ result["groups"] = self._group_updates_id_gen.get_current_token()
+ return result
+
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "groups":
+ self._group_updates_id_gen.advance(token)
+ for row in rows:
+ self._group_updates_stream_cache.entity_has_changed(
+ row.user_id, token
+ )
+
+ return super(SlavedGroupServerStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 3ea3ca5a6f..6c1beca4e3 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -160,7 +160,11 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME, stream.last_token, stream.upto_token
)
- updates, current_token = yield stream.get_updates()
+ try:
+ updates, current_token = yield stream.get_updates()
+ except:
+ logger.info("Failed to handle stream %s", stream.NAME)
+ raise
logger.debug(
"Sending %d updates to %d connections",
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index fbafe12cc2..4c60bf79f9 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
"state_key", # str
"event_id", # str, optional
))
+GroupsStreamRow = namedtuple("GroupsStreamRow", (
+ "group_id", # str
+ "user_id", # str
+ "type", # str
+ "content", # dict
+))
class Stream(object):
@@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
super(CurrentStateDeltaStream, self).__init__(hs)
+class GroupServerStream(Stream):
+ NAME = "groups"
+ ROW_TYPE = GroupsStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_group_stream_token
+ self.update_function = store.get_all_groups_changes
+
+ super(GroupServerStream, self).__init__(hs)
+
+
STREAMS_MAP = {
stream.NAME: stream
for stream in (
@@ -482,5 +501,6 @@ STREAMS_MAP = {
TagAccountDataStream,
AccountDataStream,
CurrentStateDeltaStream,
+ GroupServerStream,
)
}
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3d809d181b..16f5a73b95 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -52,6 +52,7 @@ from synapse.rest.client.v2_alpha import (
thirdparty,
sendtodevice,
user_directory,
+ groups,
)
from synapse.http.server import JsonResource
@@ -102,3 +103,4 @@ class ClientRestResource(JsonResource):
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)
+ groups.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1a5045c9ec..d7edc34245 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- displayname = yield self.handlers.profile_handler.get_displayname(
+ displayname = yield self.profile_handler.get_displayname(
user,
)
@@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
except:
defer.returnValue((400, "Unable to parse name"))
- yield self.handlers.profile_handler.set_displayname(
+ yield self.profile_handler.set_displayname(
user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ avatar_url = yield self.profile_handler.get_avatar_url(
user,
)
@@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
except:
defer.returnValue((400, "Unable to parse name"))
- yield self.handlers.profile_handler.set_avatar_url(
+ yield self.profile_handler.set_avatar_url(
user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- displayname = yield self.handlers.profile_handler.get_displayname(
+ displayname = yield self.profile_handler.get_displayname(
user,
)
- avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ avatar_url = yield self.profile_handler.get_avatar_url(
user,
)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index cd388770c8..6c379d53ac 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -398,22 +398,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
- self.state = hs.get_state_handler()
+ self.message_handler = hs.get_handlers().message_handler
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ users_with_profile = yield self.message_handler.get_joined_members(
+ requester, room_id,
+ )
defer.returnValue((200, {
- "joined": {
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
- }
- for user_id, profile in users_with_profile.iteritems()
- }
+ "joined": users_with_profile,
}))
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
new file mode 100644
index 0000000000..d11bccc1da
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -0,0 +1,717 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import GroupID
+
+from ._base import client_v2_patterns
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class GroupServlet(RestServlet):
+ """Get the group profile
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
+
+ def __init__(self, hs):
+ super(GroupServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
+
+ defer.returnValue((200, group_description))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ yield self.groups_handler.update_group_profile(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class GroupSummaryServlet(RestServlet):
+ """Get the full group summary
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
+
+ def __init__(self, hs):
+ super(GroupSummaryServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
+
+ defer.returnValue((200, get_group_summary))
+
+
+class GroupSummaryRoomsCatServlet(RestServlet):
+ """Update/delete a rooms entry in the summary.
+
+ Matches both:
+ - /groups/:group/summary/rooms/:room_id
+ - /groups/:group/summary/categories/:category/rooms/:room_id
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/categories/(?P<category_id>[^/]+))?"
+ "/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSummaryRoomsCatServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, category_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_summary_room(
+ group_id, user_id,
+ room_id=room_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, category_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_summary_room(
+ group_id, user_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupCategoryServlet(RestServlet):
+ """Get/add/update/delete a group category
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupCategoryServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_category(
+ group_id, user_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, category))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_category(
+ group_id, user_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_category(
+ group_id, user_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupCategoriesServlet(RestServlet):
+ """Get all group categories
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/categories/$"
+ )
+
+ def __init__(self, hs):
+ super(GroupCategoriesServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_categories(
+ group_id, user_id,
+ )
+
+ defer.returnValue((200, category))
+
+
+class GroupRoleServlet(RestServlet):
+ """Get/add/update/delete a group role
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupRoleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_role(
+ group_id, user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, category))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_role(
+ group_id, user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_role(
+ group_id, user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupRolesServlet(RestServlet):
+ """Get all group roles
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/roles/$"
+ )
+
+ def __init__(self, hs):
+ super(GroupRolesServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_roles(
+ group_id, user_id,
+ )
+
+ defer.returnValue((200, category))
+
+
+class GroupSummaryUsersRoleServlet(RestServlet):
+ """Update/delete a user's entry in the summary.
+
+ Matches both:
+ - /groups/:group/summary/users/:room_id
+ - /groups/:group/summary/roles/:role/users/:user_id
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSummaryUsersRoleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, role_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, role_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupRoomServlet(RestServlet):
+ """Get all rooms in a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
+
+ def __init__(self, hs):
+ super(GroupRoomServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupUsersServlet(RestServlet):
+ """Get all users in a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
+
+ def __init__(self, hs):
+ super(GroupUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_users_in_group(group_id, user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupInvitedUsersServlet(RestServlet):
+ """Get users invited to a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
+
+ def __init__(self, hs):
+ super(GroupInvitedUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupCreateServlet(RestServlet):
+ """Create a group
+ """
+ PATTERNS = client_v2_patterns("/create_group$")
+
+ def __init__(self, hs):
+ super(GroupCreateServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+ self.server_name = hs.hostname
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ # TODO: Create group on remote server
+ content = parse_json_object_from_request(request)
+ localpart = content.pop("localpart")
+ group_id = GroupID.create(localpart, self.server_name).to_string()
+
+ result = yield self.groups_handler.create_group(group_id, user_id, content)
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminRoomsServlet(RestServlet):
+ """Add a room to the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminRoomsServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.add_room_to_group(
+ group_id, user_id, room_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.remove_room_from_group(
+ group_id, user_id, room_id,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminUsersInviteServlet(RestServlet):
+ """Invite a user to the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminUsersInviteServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+ self.store = hs.get_datastore()
+ self.is_mine_id = hs.is_mine_id
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ config = content.get("config", {})
+ result = yield self.groups_handler.invite(
+ group_id, user_id, requester_user_id, config,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminUsersKickServlet(RestServlet):
+ """Kick a user from the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminUsersKickServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfLeaveServlet(RestServlet):
+ """Leave a joined group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/leave$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfLeaveServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.remove_user_from_group(
+ group_id, requester_user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfJoinServlet(RestServlet):
+ """Attempt to join a group, or knock
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/join$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfJoinServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.join_group(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfAcceptInviteServlet(RestServlet):
+ """Accept a group invite
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfAcceptInviteServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.accept_invite(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfUpdatePublicityServlet(RestServlet):
+ """Update whether we publicise a users membership of a group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfUpdatePublicityServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ publicise = content["publicise"]
+ yield self.store.update_group_publicity(
+ group_id, requester_user_id, publicise,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class PublicisedGroupsForUserServlet(RestServlet):
+ """Get the list of groups a user is advertising
+ """
+ PATTERNS = client_v2_patterns(
+ "/publicised_groups/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(PublicisedGroupsForUserServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ yield self.auth.get_user_by_req(request)
+
+ result = yield self.groups_handler.get_publicised_groups_for_user(
+ user_id
+ )
+
+ defer.returnValue((200, result))
+
+
+class PublicisedGroupsForUsersServlet(RestServlet):
+ """Get the list of groups a user is advertising
+ """
+ PATTERNS = client_v2_patterns(
+ "/publicised_groups$"
+ )
+
+ def __init__(self, hs):
+ super(PublicisedGroupsForUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ yield self.auth.get_user_by_req(request)
+
+ content = parse_json_object_from_request(request)
+ user_ids = content["user_ids"]
+
+ result = yield self.groups_handler.bulk_get_publicised_groups(
+ user_ids
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupsForUserServlet(RestServlet):
+ """Get all groups the logged in user is joined to
+ """
+ PATTERNS = client_v2_patterns(
+ "/joined_groups$"
+ )
+
+ def __init__(self, hs):
+ super(GroupsForUserServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_joined_groups(user_id)
+
+ defer.returnValue((200, result))
+
+
+def register_servlets(hs, http_server):
+ GroupServlet(hs).register(http_server)
+ GroupSummaryServlet(hs).register(http_server)
+ GroupInvitedUsersServlet(hs).register(http_server)
+ GroupUsersServlet(hs).register(http_server)
+ GroupRoomServlet(hs).register(http_server)
+ GroupCreateServlet(hs).register(http_server)
+ GroupAdminRoomsServlet(hs).register(http_server)
+ GroupAdminUsersInviteServlet(hs).register(http_server)
+ GroupAdminUsersKickServlet(hs).register(http_server)
+ GroupSelfLeaveServlet(hs).register(http_server)
+ GroupSelfJoinServlet(hs).register(http_server)
+ GroupSelfAcceptInviteServlet(hs).register(http_server)
+ GroupsForUserServlet(hs).register(http_server)
+ GroupCategoryServlet(hs).register(http_server)
+ GroupCategoriesServlet(hs).register(http_server)
+ GroupSummaryRoomsCatServlet(hs).register(http_server)
+ GroupRoleServlet(hs).register(http_server)
+ GroupRolesServlet(hs).register(http_server)
+ GroupSelfUpdatePublicityServlet(hs).register(http_server)
+ GroupSummaryUsersRoleServlet(hs).register(http_server)
+ PublicisedGroupsForUserServlet(hs).register(http_server)
+ PublicisedGroupsForUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 1421c18152..d9a8cdbbb5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -17,8 +17,10 @@
from twisted.internet import defer
import synapse
+import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
+from synapse.types import RoomID, RoomAlias
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
@@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
+ self.room_member_handler = hs.get_handlers().room_member_handler
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet):
generate_token=False,
)
+ # auto-join the user to any rooms we're supposed to dump them into
+ fake_requester = synapse.types.create_requester(registered_user_id)
+ for r in self.hs.config.auto_join_rooms:
+ try:
+ yield self._join_user_to_room(fake_requester, r)
+ except Exception as e:
+ logger.error("Failed to join new user to %r: %r", r, e)
+
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
@@ -373,6 +384,29 @@ class RegisterRestServlet(RestServlet):
return 200, {}
@defer.inlineCallbacks
+ def _join_user_to_room(self, requester, room_identifier):
+ room_id = None
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, remote_room_hosts = (
+ yield self.room_member_handler.lookup_room_alias(room_alias)
+ )
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(400, "%s was not legal room ID or room alias" % (
+ room_identifier,
+ ))
+
+ yield self.room_member_handler.update_membership(
+ requester=requester,
+ target=requester.user,
+ room_id=room_id,
+ action="join",
+ )
+
+ @defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token, body):
user_id = yield self.registration_handler.appservice_register(
username, as_token
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 978af9c280..a1e0e53b33 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -200,6 +200,11 @@ class SyncRestServlet(RestServlet):
"invite": invited,
"leave": archived,
},
+ "groups": {
+ "join": sync_result.groups.join,
+ "invite": sync_result.groups.invite,
+ "leave": sync_result.groups.leave,
+ },
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
}
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index d92b7ff337..d5164e47e0 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -14,78 +14,200 @@
# limitations under the License.
import os
+import re
+import functools
+
+NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
+
+
+def _wrap_in_base_path(func):
+ """Takes a function that returns a relative path and turns it into an
+ absolute path based on the location of the primary media store
+ """
+ @functools.wraps(func)
+ def _wrapped(self, *args, **kwargs):
+ path = func(self, *args, **kwargs)
+ return os.path.join(self.base_path, path)
+
+ return _wrapped
class MediaFilePaths(object):
+ """Describes where files are stored on disk.
- def __init__(self, base_path):
- self.base_path = base_path
+ Most of the functions have a `*_rel` variant which returns a file path that
+ is relative to the base media store path. This is mainly used when we want
+ to write to the backup media store (when one is configured)
+ """
- def default_thumbnail(self, default_top_level, default_sub_type, width,
- height, content_type, method):
+ def __init__(self, primary_base_path):
+ self.base_path = primary_base_path
+
+ def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
+ height, content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
- self.base_path, "default_thumbnails", default_top_level,
+ "default_thumbnails", default_top_level,
default_sub_type, file_name
)
- def local_media_filepath(self, media_id):
+ default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
+
+ def local_media_filepath_rel(self, media_id):
return os.path.join(
- self.base_path, "local_content",
+ "local_content",
media_id[0:2], media_id[2:4], media_id[4:]
)
- def local_media_thumbnail(self, media_id, width, height, content_type,
- method):
+ local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
+
+ def local_media_thumbnail_rel(self, media_id, width, height, content_type,
+ method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
- self.base_path, "local_thumbnails",
+ "local_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
file_name
)
- def remote_media_filepath(self, server_name, file_id):
+ local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
+
+ def remote_media_filepath_rel(self, server_name, file_id):
return os.path.join(
- self.base_path, "remote_content", server_name,
+ "remote_content", server_name,
file_id[0:2], file_id[2:4], file_id[4:]
)
- def remote_media_thumbnail(self, server_name, file_id, width, height,
- content_type, method):
+ remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
+
+ def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
+ content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
- self.base_path, "remote_thumbnail", server_name,
+ "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
file_name
)
+ remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
)
- def url_cache_filepath(self, media_id):
- return os.path.join(
- self.base_path, "url_cache",
- media_id[0:2], media_id[2:4], media_id[4:]
- )
+ def url_cache_filepath_rel(self, media_id):
+ if NEW_FORMAT_ID_RE.match(media_id):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+ return os.path.join(
+ "url_cache",
+ media_id[:10], media_id[11:]
+ )
+ else:
+ return os.path.join(
+ "url_cache",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ )
+
+ url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
+
+ def url_cache_filepath_dirs_to_delete(self, media_id):
+ "The dirs to try and remove if we delete the media_id file"
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return [
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[:10],
+ ),
+ ]
+ else:
+ return [
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2], media_id[2:4],
+ ),
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2],
+ ),
+ ]
+
+ def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
+ method):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
- def url_cache_thumbnail(self, media_id, width, height, content_type,
- method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
- return os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- file_name
- )
+
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return os.path.join(
+ "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ file_name
+ )
+ else:
+ return os.path.join(
+ "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ file_name
+ )
+
+ url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
+
+ def url_cache_thumbnail_directory(self, media_id):
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ )
+ else:
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ )
+
+ def url_cache_thumbnail_dirs_to_delete(self, media_id):
+ "The dirs to try and remove if we delete the media_id thumbnails"
+ # Media id is of the form <DATE><RANDOM_STRING>
+ # E.g.: 2017-09-28-fsdRDt24DS234dsf
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return [
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10],
+ ),
+ ]
+ else:
+ return [
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2],
+ ),
+ ]
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0ea1248ce6..6b50b45b1f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -33,7 +33,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.retryutils import NotRetryingDestination
import os
@@ -59,7 +59,14 @@ class MediaRepository(object):
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+
+ self.primary_base_path = hs.config.media_store_path
+ self.filepaths = MediaFilePaths(self.primary_base_path)
+
+ self.backup_base_path = hs.config.backup_media_store_path
+
+ self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
+
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
@@ -87,18 +94,86 @@ class MediaRepository(object):
if not os.path.exists(dirname):
os.makedirs(dirname)
+ @staticmethod
+ def _write_file_synchronously(source, fname):
+ """Write `source` to the path `fname` synchronously. Should be called
+ from a thread.
+
+ Args:
+ source: A file like object to be written
+ fname (str): Path to write to
+ """
+ MediaRepository._makedirs(fname)
+ source.seek(0) # Ensure we read from the start of the file
+ with open(fname, "wb") as f:
+ shutil.copyfileobj(source, f)
+
+ @defer.inlineCallbacks
+ def write_to_file_and_backup(self, source, path):
+ """Write `source` to the on disk media store, and also the backup store
+ if configured.
+
+ Args:
+ source: A file like object that should be written
+ path (str): Relative path to write file to
+
+ Returns:
+ Deferred[str]: the file path written to in the primary media store
+ """
+ fname = os.path.join(self.primary_base_path, path)
+
+ # Write to the main repository
+ yield make_deferred_yieldable(threads.deferToThread(
+ self._write_file_synchronously, source, fname,
+ ))
+
+ # Write to backup repository
+ yield self.copy_to_backup(path)
+
+ defer.returnValue(fname)
+
+ @defer.inlineCallbacks
+ def copy_to_backup(self, path):
+ """Copy a file from the primary to backup media store, if configured.
+
+ Args:
+ path(str): Relative path to write file to
+ """
+ if self.backup_base_path:
+ primary_fname = os.path.join(self.primary_base_path, path)
+ backup_fname = os.path.join(self.backup_base_path, path)
+
+ # We can either wait for successful writing to the backup repository
+ # or write in the background and immediately return
+ if self.synchronous_backup_media_store:
+ yield make_deferred_yieldable(threads.deferToThread(
+ shutil.copyfile, primary_fname, backup_fname,
+ ))
+ else:
+ preserve_fn(threads.deferToThread)(
+ shutil.copyfile, primary_fname, backup_fname,
+ )
+
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
+ """Store uploaded content for a local user and return the mxc URL
+
+ Args:
+ media_type(str): The content type of the file
+ upload_name(str): The name of the file
+ content: A file like object that is the content to store
+ content_length(int): The length of the content
+ auth_user(str): The user_id of the uploader
+
+ Returns:
+ Deferred[str]: The mxc url of the stored content
+ """
media_id = random_string(24)
- fname = self.filepaths.local_media_filepath(media_id)
- self._makedirs(fname)
-
- # This shouldn't block for very long because the content will have
- # already been uploaded at this point.
- with open(fname, "wb") as f:
- f.write(content)
+ fname = yield self.write_to_file_and_backup(
+ content, self.filepaths.local_media_filepath_rel(media_id)
+ )
logger.info("Stored local media in file %r", fname)
@@ -115,7 +190,7 @@ class MediaRepository(object):
"media_length": content_length,
}
- yield self._generate_local_thumbnails(media_id, media_info)
+ yield self._generate_thumbnails(None, media_id, media_info)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@@ -148,9 +223,10 @@ class MediaRepository(object):
def _download_remote_file(self, server_name, media_id):
file_id = random_string(24)
- fname = self.filepaths.remote_media_filepath(
+ fpath = self.filepaths.remote_media_filepath_rel(
server_name, file_id
)
+ fname = os.path.join(self.primary_base_path, fpath)
self._makedirs(fname)
try:
@@ -192,6 +268,8 @@ class MediaRepository(object):
server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media")
+ yield self.copy_to_backup(fpath)
+
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
@@ -244,7 +322,7 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
- yield self._generate_remote_thumbnails(
+ yield self._generate_thumbnails(
server_name, media_id, media_info
)
@@ -253,9 +331,8 @@ class MediaRepository(object):
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
- def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
+ def _generate_thumbnail(self, thumbnailer, t_width, t_height,
t_method, t_type):
- thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -267,72 +344,105 @@ class MediaRepository(object):
return
if t_method == "crop":
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+ t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+ t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
else:
- t_len = None
+ t_byte_source = None
- return t_len
+ return t_byte_source
@defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type):
input_path = self.filepaths.local_media_filepath(media_id)
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
-
- t_len = yield preserve_context_over_fn(
- threads.deferToThread,
+ thumbnailer = Thumbnailer(input_path)
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
self._generate_thumbnail,
- input_path, t_path, t_width, t_height, t_method, t_type
- )
+ thumbnailer, t_width, t_height, t_method, t_type
+ ))
+
+ if t_byte_source:
+ try:
+ output_path = yield self.write_to_file_and_backup(
+ t_byte_source,
+ self.filepaths.local_media_thumbnail_rel(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ )
+ finally:
+ t_byte_source.close()
+
+ logger.info("Stored thumbnail in file %r", output_path)
+
+ t_len = os.path.getsize(output_path)
- if t_len:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(t_path)
+ defer.returnValue(output_path)
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
-
- t_len = yield preserve_context_over_fn(
- threads.deferToThread,
+ thumbnailer = Thumbnailer(input_path)
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
self._generate_thumbnail,
- input_path, t_path, t_width, t_height, t_method, t_type
- )
+ thumbnailer, t_width, t_height, t_method, t_type
+ ))
+
+ if t_byte_source:
+ try:
+ output_path = yield self.write_to_file_and_backup(
+ t_byte_source,
+ self.filepaths.remote_media_thumbnail_rel(
+ server_name, file_id, t_width, t_height, t_type, t_method
+ )
+ )
+ finally:
+ t_byte_source.close()
+
+ logger.info("Stored thumbnail in file %r", output_path)
+
+ t_len = os.path.getsize(output_path)
- if t_len:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(t_path)
+ defer.returnValue(output_path)
@defer.inlineCallbacks
- def _generate_local_thumbnails(self, media_id, media_info, url_cache=False):
+ def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
+ """Generate and store thumbnails for an image.
+
+ Args:
+ server_name(str|None): The server name if remote media, else None if local
+ media_id(str)
+ media_info(dict)
+ url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
+ used exclusively by the url previewer
+
+ Returns:
+ Deferred[dict]: Dict with "width" and "height" keys of original image
+ """
media_type = media_info["media_type"]
+ file_id = media_info.get("filesystem_id")
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
- if url_cache:
+ if server_name:
+ input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+ elif url_cache:
input_path = self.filepaths.url_cache_filepath(media_id)
else:
input_path = self.filepaths.local_media_filepath(media_id)
@@ -348,135 +458,72 @@ class MediaRepository(object):
)
return
- local_thumbnails = []
-
- def generate_thumbnails():
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
- ))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
-
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- if url_cache:
- t_path = self.filepaths.url_cache_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- else:
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+ # We deduplicate the thumbnail sizes by ignoring the cropped versions if
+ # they have the same dimensions of a scaled one.
+ thumbnails = {}
+ for r_width, r_height, r_method, r_type in requirements:
+ if r_method == "crop":
+ thumbnails.setdefault((r_width, r_height, r_type), r_method)
+ elif r_method == "scale":
+ t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ t_width = min(m_width, t_width)
+ t_height = min(m_height, t_height)
+ thumbnails[(t_width, t_height, r_type)] = r_method
+
+ # Now we generate the thumbnails for each dimension, store it
+ for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
+ # Work out the correct file name for thumbnail
+ if server_name:
+ file_path = self.filepaths.remote_media_thumbnail_rel(
+ server_name, file_id, t_width, t_height, t_type, t_method
+ )
+ elif url_cache:
+ file_path = self.filepaths.url_cache_thumbnail_rel(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ else:
+ file_path = self.filepaths.local_media_thumbnail_rel(
+ media_id, t_width, t_height, t_type, t_method
+ )
- local_thumbnails.append((
- media_id, t_width, t_height, t_type, t_method, t_len
+ # Generate the thumbnail
+ if t_method == "crop":
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ thumbnailer.crop,
+ t_width, t_height, t_type,
))
-
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- if url_cache:
- t_path = self.filepaths.url_cache_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- else:
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- local_thumbnails.append((
- media_id, t_width, t_height, t_type, t_method, t_len
+ elif t_method == "scale":
+ t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ thumbnailer.scale,
+ t_width, t_height, t_type,
))
+ else:
+ logger.error("Unrecognized method: %r", t_method)
+ continue
- yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
- for l in local_thumbnails:
- yield self.store.store_local_thumbnail(*l)
-
- defer.returnValue({
- "width": m_width,
- "height": m_height,
- })
-
- @defer.inlineCallbacks
- def _generate_remote_thumbnails(self, server_name, media_id, media_info):
- media_type = media_info["media_type"]
- file_id = media_info["filesystem_id"]
- requirements = self._get_thumbnail_requirements(media_type)
- if not requirements:
- return
+ if not t_byte_source:
+ continue
- remote_thumbnails = []
+ try:
+ # Write to disk
+ output_path = yield self.write_to_file_and_backup(
+ t_byte_source, file_path,
+ )
+ finally:
+ t_byte_source.close()
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
- thumbnailer = Thumbnailer(input_path)
- m_width = thumbnailer.width
- m_height = thumbnailer.height
+ t_len = os.path.getsize(output_path)
- def generate_thumbnails():
- if m_width * m_height >= self.max_image_pixels:
- logger.info(
- "Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
- )
- return
-
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
- ))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
-
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
- remote_thumbnails.append([
+ # Write to database
+ if server_name:
+ yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
- ])
-
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- t_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method
)
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- remote_thumbnails.append([
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
- ])
-
- yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
- for r in remote_thumbnails:
- yield self.store.store_remote_media_thumbnail(*r)
+ else:
+ yield self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
+ )
defer.returnValue({
"width": m_width,
@@ -497,6 +544,8 @@ class MediaRepository(object):
logger.info("Deleting: %r", key)
+ # TODO: Should we delete from the backup store
+
with (yield self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b81a336c5d..2a3e37fdf4 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -36,6 +36,9 @@ import cgi
import ujson as json
import urlparse
import itertools
+import datetime
+import errno
+import shutil
import logging
logger = logging.getLogger(__name__)
@@ -56,6 +59,7 @@ class PreviewUrlResource(Resource):
self.store = hs.get_datastore()
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
+ self.primary_base_path = media_repo.primary_base_path
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@@ -70,6 +74,10 @@ class PreviewUrlResource(Resource):
self.downloads = {}
+ self._cleaner_loop = self.clock.looping_call(
+ self._expire_url_cache_data, 10 * 1000
+ )
+
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@@ -130,7 +138,7 @@ class PreviewUrlResource(Resource):
cache_result = yield self.store.get_url_cache(url, ts)
if (
cache_result and
- cache_result["download_ts"] + cache_result["expires"] > ts and
+ cache_result["expires_ts"] > ts and
cache_result["response_code"] / 100 == 2
):
respond_with_json_bytes(
@@ -163,8 +171,8 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']):
- dims = yield self.media_repo._generate_local_thumbnails(
- media_info['filesystem_id'], media_info, url_cache=True,
+ dims = yield self.media_repo._generate_thumbnails(
+ None, media_info['filesystem_id'], media_info, url_cache=True,
)
og = {
@@ -209,8 +217,8 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
- dims = yield self.media_repo._generate_local_thumbnails(
- image_info['filesystem_id'], image_info, url_cache=True,
+ dims = yield self.media_repo._generate_thumbnails(
+ None, image_info['filesystem_id'], image_info, url_cache=True,
)
if dims:
og["og:image:width"] = dims['width']
@@ -239,7 +247,7 @@ class PreviewUrlResource(Resource):
url,
media_info["response_code"],
media_info["etag"],
- media_info["expires"],
+ media_info["expires"] + media_info["created_ts"],
json.dumps(og),
media_info["filesystem_id"],
media_info["created_ts"],
@@ -253,10 +261,10 @@ class PreviewUrlResource(Resource):
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
- # XXX: horrible duplication with base_resource's _download_remote_file()
- file_id = random_string(24)
+ file_id = datetime.date.today().isoformat() + '_' + random_string(16)
- fname = self.filepaths.url_cache_filepath(file_id)
+ fpath = self.filepaths.url_cache_filepath_rel(file_id)
+ fname = os.path.join(self.primary_base_path, fpath)
self.media_repo._makedirs(fname)
try:
@@ -267,6 +275,8 @@ class PreviewUrlResource(Resource):
)
# FIXME: pass through 404s and other error messages nicely
+ yield self.media_repo.copy_to_backup(fpath)
+
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
@@ -328,6 +338,91 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None,
})
+ @defer.inlineCallbacks
+ def _expire_url_cache_data(self):
+ """Clean up expired url cache content, media and thumbnails.
+ """
+
+ # TODO: Delete from backup media store
+
+ now = self.clock.time_msec()
+
+ # First we delete expired url cache entries
+ media_ids = yield self.store.get_expired_url_cache(now)
+
+ removed_media = []
+ for media_id in media_ids:
+ fname = self.filepaths.url_cache_filepath(media_id)
+ try:
+ os.remove(fname)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ removed_media.append(media_id)
+
+ try:
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except:
+ pass
+
+ yield self.store.delete_url_cache(removed_media)
+
+ if removed_media:
+ logger.info("Deleted %d entries from url cache", len(removed_media))
+
+ # Now we delete old images associated with the url cache.
+ # These may be cached for a bit on the client (i.e., they
+ # may have a room open with a preview url thing open).
+ # So we wait a couple of days before deleting, just in case.
+ expire_before = now - 2 * 24 * 60 * 60 * 1000
+ media_ids = yield self.store.get_url_cache_media_before(expire_before)
+
+ removed_media = []
+ for media_id in media_ids:
+ fname = self.filepaths.url_cache_filepath(media_id)
+ try:
+ os.remove(fname)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ try:
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except:
+ pass
+
+ thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
+ try:
+ shutil.rmtree(thumbnail_dir)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ removed_media.append(media_id)
+
+ try:
+ dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except:
+ pass
+
+ yield self.store.delete_url_cache_media(removed_media)
+
+ if removed_media:
+ logger.info("Deleted %d media from url cache", len(removed_media))
+
def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 3868d4f65f..e1ee535b9a 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -50,12 +50,16 @@ class Thumbnailer(object):
else:
return ((max_height * self.width) // self.height, max_height)
- def scale(self, output_path, width, height, output_type):
- """Rescales the image to the given dimensions"""
+ def scale(self, width, height, output_type):
+ """Rescales the image to the given dimensions.
+
+ Returns:
+ BytesIO: the bytes of the encoded image ready to be written to disk
+ """
scaled = self.image.resize((width, height), Image.ANTIALIAS)
- return self.save_image(scaled, output_type, output_path)
+ return self._encode_image(scaled, output_type)
- def crop(self, output_path, width, height, output_type):
+ def crop(self, width, height, output_type):
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@@ -65,6 +69,9 @@ class Thumbnailer(object):
Args:
max_width: The largest possible width.
max_height: The larget possible height.
+
+ Returns:
+ BytesIO: the bytes of the encoded image ready to be written to disk
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
@@ -82,13 +89,9 @@ class Thumbnailer(object):
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
- return self.save_image(cropped, output_type, output_path)
+ return self._encode_image(cropped, output_type)
- def save_image(self, output_image, output_type, output_path):
+ def _encode_image(self, output_image, output_type):
output_bytes_io = BytesIO()
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
- output_bytes = output_bytes_io.getvalue()
- with open(output_path, "wb") as output_file:
- output_file.write(output_bytes)
- logger.info("Stored thumbnail in file %r", output_path)
- return len(output_bytes)
+ return output_bytes_io
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 4ab33f73bf..f6f498cdc5 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -93,7 +93,7 @@ class UploadResource(Resource):
# TODO(markjh): parse content-dispostion
content_uri = yield self.media_repo.create_content(
- media_type, upload_name, request.content.read(),
+ media_type, upload_name, request.content,
content_length, requester.user
)
diff --git a/synapse/server.py b/synapse/server.py
index a38e5179e0..10e3e9a4f1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
+from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.transport.client import TransportLayerClient
@@ -50,6 +51,10 @@ from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.handlers.profile import ProfileHandler
+from synapse.groups.groups_server import GroupsServerHandler
+from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
@@ -111,6 +116,7 @@ class HomeServer(object):
'application_service_scheduler',
'application_service_handler',
'device_message_handler',
+ 'profile_handler',
'notifier',
'distributor',
'client_resource',
@@ -139,6 +145,11 @@ class HomeServer(object):
'read_marker_handler',
'action_generator',
'user_directory_handler',
+ 'groups_local_handler',
+ 'groups_server_handler',
+ 'groups_attestation_signing',
+ 'groups_attestation_renewer',
+ 'spam_checker',
]
def __init__(self, hostname, **kwargs):
@@ -251,6 +262,9 @@ class HomeServer(object):
def build_initial_sync_handler(self):
return InitialSyncHandler(self)
+ def build_profile_handler(self):
+ return ProfileHandler(self)
+
def build_event_sources(self):
return EventSources(self)
@@ -309,6 +323,21 @@ class HomeServer(object):
def build_user_directory_handler(self):
return UserDirectoyHandler(self)
+ def build_groups_local_handler(self):
+ return GroupsLocalHandler(self)
+
+ def build_groups_server_handler(self):
+ return GroupsServerHandler(self)
+
+ def build_groups_attestation_signing(self):
+ return GroupAttestationSigning(self)
+
+ def build_groups_attestation_renewer(self):
+ return GroupAttestionRenewer(self)
+
+ def build_spam_checker(self):
+ return SpamChecker(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 9570df5537..e8c0386b7f 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,4 +1,6 @@
import synapse.api.auth
+import synapse.federation.transaction_queue
+import synapse.federation.transport.client
import synapse.handlers
import synapse.handlers.auth
import synapse.handlers.device
@@ -27,3 +29,9 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler:
pass
+
+ def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
+ pass
+
+ def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
+ pass
diff --git a/synapse/state.py b/synapse/state.py
index 390799fbd5..dcdcdef65e 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -288,6 +288,9 @@ class StateHandler(object):
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
+ # map from state group id to the state in that state group (where
+ # 'state' is a map from state key to event id)
+ # dict[int, dict[(str, str), str]]
state_groups_ids = yield self.store.get_state_groups_ids(
room_id, event_ids
)
@@ -320,11 +323,15 @@ class StateHandler(object):
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
+ # build a map from state key to the event_ids which set that state.
+ # dict[(str, str), set[str])
state = {}
for st in state_groups_ids.values():
for key, e_id in st.items():
state.setdefault(key, set()).add(e_id)
+ # build a map from state key to the event_ids which set that state,
+ # including only those where there are state keys in conflict.
conflicted_state = {
k: list(v)
for k, v in state.items()
@@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
logger.info("Asking for %d conflicted events", len(needed_events))
+ # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+ # the state events which are in conflict.
state_map = yield state_map_factory(needed_events)
+ # get the ids of the auth events which allow us to authenticate the
+ # conflicted state, picking only from the unconflicting state.
+ #
+ # dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index b92472df33..594566eb38 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -37,7 +37,7 @@ from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore
from .event_push_actions import EventPushActionsStore
from .deviceinbox import DeviceInboxStore
-
+from .group_server import GroupServerStore
from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
@@ -88,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
+ GroupServerStore,
):
def __init__(self, db_conn, hs):
@@ -135,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")],
)
+ self._group_updates_id_gen = StreamIdGenerator(
+ db_conn, "local_group_updates", "stream_id",
+ )
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
@@ -235,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=curr_state_delta_prefill,
)
+ _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+ db_conn, "local_group_updates",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._group_updates_id_gen.get_current_token(),
+ limit=1000,
+ )
+ self._group_updates_stream_cache = StreamChangeCache(
+ "_group_updates_stream_cache", min_group_updates_id,
+ prefilled_cache=_group_updates_prefill,
+ )
+
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6f54036d67..5124a833a5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -743,6 +743,33 @@ class SQLBaseStore(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
+ def _simple_update(self, table, keyvalues, updatevalues, desc):
+ return self.runInteraction(
+ desc,
+ self._simple_update_txn,
+ table, keyvalues, updatevalues,
+ )
+
+ @staticmethod
+ def _simple_update_txn(txn, table, keyvalues, updatevalues):
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+ else:
+ where = ""
+
+ update_sql = "UPDATE %s SET %s %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ where,
+ )
+
+ txn.execute(
+ update_sql,
+ updatevalues.values() + keyvalues.values()
+ )
+
+ return txn.rowcount
+
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
@@ -768,27 +795,13 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
- @staticmethod
- def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
- else:
- where = ""
-
- update_sql = "UPDATE %s SET %s %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- where,
- )
-
- txn.execute(
- update_sql,
- updatevalues.values() + keyvalues.values()
- )
+ @classmethod
+ def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
- if txn.rowcount == 0:
+ if rowcount == 0:
raise StoreError(404, "No row found")
- if txn.rowcount > 1:
+ if rowcount > 1:
raise StoreError(500, "More than one row matched")
@staticmethod
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7002b3752e..637640ec2a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -21,7 +21,7 @@ from synapse.events.utils import prune_event
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
- preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
+ preserve_fn, PreserveLoggingContext, make_deferred_yieldable
)
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
@@ -88,13 +88,23 @@ class _EventPeristenceQueue(object):
def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options.
+ NB: due to the normal usage pattern of this method, it does *not*
+ follow the synapse logcontext rules, and leaves the logcontext in
+ place whether or not the returned deferred is ready.
+
Args:
room_id (str):
events_and_contexts (list[(EventBase, EventContext)]):
backfilled (bool):
+
+ Returns:
+ defer.Deferred: a deferred which will resolve once the events are
+ persisted. Runs its callbacks *without* a logcontext.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
end_item = queue[-1]
if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts)
@@ -113,11 +123,11 @@ class _EventPeristenceQueue(object):
def handle_queue(self, room_id, per_item_callback):
"""Attempts to handle the queue for a room if not already being handled.
- The given callback will be invoked with for each item in the queue,1
+ The given callback will be invoked with for each item in the queue,
of type _EventPersistQueueItem. The per_item_callback will continuously
be called with new items, unless the queue becomnes empty. The return
value of the function will be given to the deferreds waiting on the item,
- exceptions will be passed to the deferres as well.
+ exceptions will be passed to the deferreds as well.
This function should therefore be called whenever anything is added
to the queue.
@@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore):
deferreds = []
for room_id, evs_ctxs in partitioned.iteritems():
- d = preserve_fn(self._event_persist_queue.add_to_queue)(
+ d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs,
backfilled=backfilled,
)
@@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
- return preserve_context_over_deferred(
+ return make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
@@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore):
self._maybe_start_persisting(event.room_id)
- yield preserve_context_over_deferred(deferred)
+ yield make_deferred_yieldable(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -784,6 +794,9 @@ class EventsStore(SQLBaseStore):
self._invalidate_cache_and_stream(
txn, self.is_host_joined, (room_id, host)
)
+ self._invalidate_cache_and_stream(
+ txn, self.was_host_joined, (room_id, host)
+ )
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
@@ -1523,7 +1536,7 @@ class EventsStore(SQLBaseStore):
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
- res = yield preserve_context_over_deferred(defer.gatherResults(
+ res = yield make_deferred_yieldable(defer.gatherResults(
[
preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
new file mode 100644
index 0000000000..9e63db5c6c
--- /dev/null
+++ b/synapse/storage/group_server.py
@@ -0,0 +1,1199 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+
+from ._base import SQLBaseStore
+
+import ujson as json
+
+
+# The category ID for the "default" category. We don't store as null in the
+# database to avoid the fun of null != null
+_DEFAULT_CATEGORY_ID = ""
+_DEFAULT_ROLE_ID = ""
+
+
+class GroupServerStore(SQLBaseStore):
+ def get_group(self, group_id):
+ return self._simple_select_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=("name", "short_description", "long_description", "avatar_url",),
+ allow_none=True,
+ desc="is_user_in_group",
+ )
+
+ def get_users_in_group(self, group_id, include_private=False):
+ # TODO: Pagination
+
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ return self._simple_select_list(
+ table="group_users",
+ keyvalues=keyvalues,
+ retcols=("user_id", "is_public",),
+ desc="get_users_in_group",
+ )
+
+ def get_invited_users_in_group(self, group_id):
+ # TODO: Pagination
+
+ return self._simple_select_onecol(
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcol="user_id",
+ desc="get_invited_users_in_group",
+ )
+
+ def get_rooms_in_group(self, group_id, include_private=False):
+ # TODO: Pagination
+
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ return self._simple_select_list(
+ table="group_rooms",
+ keyvalues=keyvalues,
+ retcols=("room_id", "is_public",),
+ desc="get_rooms_in_group",
+ )
+
+ def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+ """Get the rooms and categories that should be included in a summary request
+
+ Returns ([rooms], [categories])
+ """
+ def _get_rooms_for_summary_txn(txn):
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ sql = """
+ SELECT room_id, is_public, category_id, room_order
+ FROM group_summary_rooms
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ rooms = [
+ {
+ "room_id": row[0],
+ "is_public": row[1],
+ "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None,
+ "order": row[3],
+ }
+ for row in txn
+ ]
+
+ sql = """
+ SELECT category_id, is_public, profile, cat_order
+ FROM group_summary_room_categories
+ INNER JOIN group_room_categories USING (group_id, category_id)
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ categories = {
+ row[0]: {
+ "is_public": row[1],
+ "profile": json.loads(row[2]),
+ "order": row[3],
+ }
+ for row in txn
+ }
+
+ return rooms, categories
+ return self.runInteraction(
+ "get_rooms_for_summary", _get_rooms_for_summary_txn
+ )
+
+ def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
+ return self.runInteraction(
+ "add_room_to_summary", self._add_room_to_summary_txn,
+ group_id, room_id, category_id, order, is_public,
+ )
+
+ def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order,
+ is_public):
+ """Add (or update) room's entry in summary.
+
+ Args:
+ group_id (str)
+ room_id (str)
+ category_id (str): If not None then adds the category to the end of
+ the summary if its not already there. [Optional]
+ order (int): If not None inserts the room at that position, e.g.
+ an order of 1 will put the room first. Otherwise, the room gets
+ added to the end.
+ """
+ room_in_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ retcol="room_id",
+ allow_none=True,
+ )
+ if not room_in_group:
+ raise SynapseError(400, "room not in group")
+
+ if category_id is None:
+ category_id = _DEFAULT_CATEGORY_ID
+ else:
+ cat_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not cat_exists:
+ raise SynapseError(400, "Category doesn't exist")
+
+ # TODO: Check category is part of summary already
+ cat_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_summary_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not cat_exists:
+ # If not, add it with an order larger than all others
+ txn.execute("""
+ INSERT INTO group_summary_room_categories
+ (group_id, category_id, cat_order)
+ SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
+ FROM group_summary_room_categories
+ WHERE group_id = ? AND category_id = ?
+ """, (group_id, category_id, group_id, category_id))
+
+ existing = self._simple_select_one_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ "category_id": category_id,
+ },
+ retcols=("room_order", "is_public",),
+ allow_none=True,
+ )
+
+ if order is not None:
+ # Shuffle other room orders that come after the given order
+ sql = """
+ UPDATE group_summary_rooms SET room_order = room_order + 1
+ WHERE group_id = ? AND category_id = ? AND room_order >= ?
+ """
+ txn.execute(sql, (group_id, category_id, order,))
+ elif not existing:
+ sql = """
+ SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
+ WHERE group_id = ? AND category_id = ?
+ """
+ txn.execute(sql, (group_id, category_id,))
+ order, = txn.fetchone()
+
+ if existing:
+ to_update = {}
+ if order is not None:
+ to_update["room_order"] = order
+ if is_public is not None:
+ to_update["is_public"] = is_public
+ self._simple_update_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ },
+ values=to_update,
+ )
+ else:
+ if is_public is None:
+ is_public = True
+
+ self._simple_insert_txn(
+ txn,
+ table="group_summary_rooms",
+ values={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ "room_order": order,
+ "is_public": is_public,
+ },
+ )
+
+ def remove_room_from_summary(self, group_id, room_id, category_id):
+ if category_id is None:
+ category_id = _DEFAULT_CATEGORY_ID
+
+ return self._simple_delete(
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ "room_id": room_id,
+ },
+ desc="remove_room_from_summary",
+ )
+
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id):
+ rows = yield self._simple_select_list(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=("category_id", "is_public", "profile"),
+ desc="get_group_categories",
+ )
+
+ defer.returnValue({
+ row["category_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ })
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, category_id):
+ category = yield self._simple_select_one(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ retcols=("is_public", "profile"),
+ desc="get_group_category",
+ )
+
+ category["profile"] = json.loads(category["profile"])
+
+ defer.returnValue(category)
+
+ def upsert_group_category(self, group_id, category_id, profile, is_public):
+ """Add/update room category for group
+ """
+ insertion_values = {}
+ update_values = {"category_id": category_id} # This cannot be empty
+
+ if profile is None:
+ insertion_values["profile"] = "{}"
+ else:
+ update_values["profile"] = json.dumps(profile)
+
+ if is_public is None:
+ insertion_values["is_public"] = True
+ else:
+ update_values["is_public"] = is_public
+
+ return self._simple_upsert(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ values=update_values,
+ insertion_values=insertion_values,
+ desc="upsert_group_category",
+ )
+
+ def remove_group_category(self, group_id, category_id):
+ return self._simple_delete(
+ table="group_room_categories",
+ keyvalues={
+ "group_id": group_id,
+ "category_id": category_id,
+ },
+ desc="remove_group_category",
+ )
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id):
+ rows = yield self._simple_select_list(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ },
+ retcols=("role_id", "is_public", "profile"),
+ desc="get_group_roles",
+ )
+
+ defer.returnValue({
+ row["role_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ })
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, role_id):
+ role = yield self._simple_select_one(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcols=("is_public", "profile"),
+ desc="get_group_role",
+ )
+
+ role["profile"] = json.loads(role["profile"])
+
+ defer.returnValue(role)
+
+ def upsert_group_role(self, group_id, role_id, profile, is_public):
+ """Add/remove user role
+ """
+ insertion_values = {}
+ update_values = {"role_id": role_id} # This cannot be empty
+
+ if profile is None:
+ insertion_values["profile"] = "{}"
+ else:
+ update_values["profile"] = json.dumps(profile)
+
+ if is_public is None:
+ insertion_values["is_public"] = True
+ else:
+ update_values["is_public"] = is_public
+
+ return self._simple_upsert(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ values=update_values,
+ insertion_values=insertion_values,
+ desc="upsert_group_role",
+ )
+
+ def remove_group_role(self, group_id, role_id):
+ return self._simple_delete(
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ desc="remove_group_role",
+ )
+
+ def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
+ return self.runInteraction(
+ "add_user_to_summary", self._add_user_to_summary_txn,
+ group_id, user_id, role_id, order, is_public,
+ )
+
+ def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order,
+ is_public):
+ """Add (or update) user's entry in summary.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ role_id (str): If not None then adds the role to the end of
+ the summary if its not already there. [Optional]
+ order (int): If not None inserts the user at that position, e.g.
+ an order of 1 will put the user first. Otherwise, the user gets
+ added to the end.
+ """
+ user_in_group = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ )
+ if not user_in_group:
+ raise SynapseError(400, "user not in group")
+
+ if role_id is None:
+ role_id = _DEFAULT_ROLE_ID
+ else:
+ role_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not role_exists:
+ raise SynapseError(400, "Role doesn't exist")
+
+ # TODO: Check role is part of the summary already
+ role_exists = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_summary_roles",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ },
+ retcol="group_id",
+ allow_none=True,
+ )
+ if not role_exists:
+ # If not, add it with an order larger than all others
+ txn.execute("""
+ INSERT INTO group_summary_roles
+ (group_id, role_id, role_order)
+ SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
+ FROM group_summary_roles
+ WHERE group_id = ? AND role_id = ?
+ """, (group_id, role_id, group_id, role_id))
+
+ existing = self._simple_select_one_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ "role_id": role_id,
+ },
+ retcols=("user_order", "is_public",),
+ allow_none=True,
+ )
+
+ if order is not None:
+ # Shuffle other users orders that come after the given order
+ sql = """
+ UPDATE group_summary_users SET user_order = user_order + 1
+ WHERE group_id = ? AND role_id = ? AND user_order >= ?
+ """
+ txn.execute(sql, (group_id, role_id, order,))
+ elif not existing:
+ sql = """
+ SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
+ WHERE group_id = ? AND role_id = ?
+ """
+ txn.execute(sql, (group_id, role_id,))
+ order, = txn.fetchone()
+
+ if existing:
+ to_update = {}
+ if order is not None:
+ to_update["user_order"] = order
+ if is_public is not None:
+ to_update["is_public"] = is_public
+ self._simple_update_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ },
+ values=to_update,
+ )
+ else:
+ if is_public is None:
+ is_public = True
+
+ self._simple_insert_txn(
+ txn,
+ table="group_summary_users",
+ values={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ "user_order": order,
+ "is_public": is_public,
+ },
+ )
+
+ def remove_user_from_summary(self, group_id, user_id, role_id):
+ if role_id is None:
+ role_id = _DEFAULT_ROLE_ID
+
+ return self._simple_delete(
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "role_id": role_id,
+ "user_id": user_id,
+ },
+ desc="remove_user_from_summary",
+ )
+
+ def get_users_for_summary_by_role(self, group_id, include_private=False):
+ """Get the users and roles that should be included in a summary request
+
+ Returns ([users], [roles])
+ """
+ def _get_users_for_summary_txn(txn):
+ keyvalues = {
+ "group_id": group_id,
+ }
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ sql = """
+ SELECT user_id, is_public, role_id, user_order
+ FROM group_summary_users
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ users = [
+ {
+ "user_id": row[0],
+ "is_public": row[1],
+ "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
+ "order": row[3],
+ }
+ for row in txn
+ ]
+
+ sql = """
+ SELECT role_id, is_public, profile, role_order
+ FROM group_summary_roles
+ INNER JOIN group_roles USING (group_id, role_id)
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ roles = {
+ row[0]: {
+ "is_public": row[1],
+ "profile": json.loads(row[2]),
+ "order": row[3],
+ }
+ for row in txn
+ }
+
+ return users, roles
+ return self.runInteraction(
+ "get_users_for_summary_by_role", _get_users_for_summary_txn
+ )
+
+ def is_user_in_group(self, user_id, group_id):
+ return self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ desc="is_user_in_group",
+ ).addCallback(lambda r: bool(r))
+
+ def is_user_admin_in_group(self, group_id, user_id):
+ return self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="is_admin",
+ allow_none=True,
+ desc="is_user_admin_in_group",
+ )
+
+ def add_group_invite(self, group_id, user_id):
+ """Record that the group server has invited a user
+ """
+ return self._simple_insert(
+ table="group_invites",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ desc="add_group_invite",
+ )
+
+ def is_user_invited_to_local_group(self, group_id, user_id):
+ """Has the group server invited a user?
+ """
+ return self._simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ desc="is_user_invited_to_local_group",
+ allow_none=True,
+ )
+
+ def get_users_membership_info_in_group(self, group_id, user_id):
+ """Get a dict describing the membership of a user in a group.
+
+ Example if joined:
+
+ {
+ "membership": "join",
+ "is_public": True,
+ "is_privileged": False,
+ }
+
+ Returns an empty dict if the user is not join/invite/etc
+ """
+ def _get_users_membership_in_group_txn(txn):
+ row = self._simple_select_one_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcols=("is_admin", "is_public"),
+ allow_none=True,
+ )
+
+ if row:
+ return {
+ "membership": "join",
+ "is_public": row["is_public"],
+ "is_privileged": row["is_admin"],
+ }
+
+ row = self._simple_select_one_onecol_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcol="user_id",
+ allow_none=True,
+ )
+
+ if row:
+ return {
+ "membership": "invite",
+ }
+
+ return {}
+
+ return self.runInteraction(
+ "get_users_membership_info_in_group", _get_users_membership_in_group_txn,
+ )
+
+ def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True,
+ local_attestation=None, remote_attestation=None):
+ """Add a user to the group server.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ is_admin (bool)
+ is_public (bool)
+ local_attestation (dict): The attestation the GS created to give
+ to the remote server. Optional if the user and group are on the
+ same server
+ remote_attestation (dict): The attestation given to GS by remote
+ server. Optional if the user and group are on the same server
+ """
+ def _add_user_to_group_txn(txn):
+ self._simple_insert_txn(
+ txn,
+ table="group_users",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "is_admin": is_admin,
+ "is_public": is_public,
+ },
+ )
+
+ self._simple_delete_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+
+ if local_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_renewals",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": local_attestation["valid_until_ms"],
+ },
+ )
+ if remote_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_remote",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": remote_attestation["valid_until_ms"],
+ "attestation_json": json.dumps(remote_attestation),
+ },
+ )
+
+ return self.runInteraction(
+ "add_user_to_group", _add_user_to_group_txn
+ )
+
+ def remove_user_from_group(self, group_id, user_id):
+ def _remove_user_from_group_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="group_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_invites",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_summary_users",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn)
+
+ def add_room_to_group(self, group_id, room_id, is_public):
+ return self._simple_insert(
+ table="group_rooms",
+ values={
+ "group_id": group_id,
+ "room_id": room_id,
+ "is_public": is_public,
+ },
+ desc="add_room_to_group",
+ )
+
+ def remove_room_from_group(self, group_id, room_id):
+ def _remove_room_from_group_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="group_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ )
+
+ self._simple_delete_txn(
+ txn,
+ table="group_summary_rooms",
+ keyvalues={
+ "group_id": group_id,
+ "room_id": room_id,
+ },
+ )
+ return self.runInteraction(
+ "remove_room_from_group", _remove_room_from_group_txn,
+ )
+
+ def get_publicised_groups_for_user(self, user_id):
+ """Get all groups a user is publicising
+ """
+ return self._simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={
+ "user_id": user_id,
+ "membership": "join",
+ "is_publicised": True,
+ },
+ retcol="group_id",
+ desc="get_publicised_groups_for_user",
+ )
+
+ def update_group_publicity(self, group_id, user_id, publicise):
+ """Update whether the user is publicising their membership of the group
+ """
+ return self._simple_update_one(
+ table="local_group_membership",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "is_publicised": publicise,
+ },
+ desc="update_group_publicity"
+ )
+
+ @defer.inlineCallbacks
+ def register_user_group_membership(self, group_id, user_id, membership,
+ is_admin=False, content={},
+ local_attestation=None,
+ remote_attestation=None,
+ is_publicised=False,
+ ):
+ """Registers that a local user is a member of a (local or remote) group.
+
+ Args:
+ group_id (str)
+ user_id (str)
+ membership (str)
+ is_admin (bool)
+ content (dict): Content of the membership, e.g. includes the inviter
+ if the user has been invited.
+ local_attestation (dict): If remote group then store the fact that we
+ have given out an attestation, else None.
+ remote_attestation (dict): If remote group then store the remote
+ attestation from the group, else None.
+ """
+ def _register_user_group_membership_txn(txn, next_id):
+ # TODO: Upsert?
+ self._simple_delete_txn(
+ txn,
+ table="local_group_membership",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_insert_txn(
+ txn,
+ table="local_group_membership",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "is_admin": is_admin,
+ "membership": membership,
+ "is_publicised": is_publicised,
+ "content": json.dumps(content),
+ },
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="local_group_updates",
+ values={
+ "stream_id": next_id,
+ "group_id": group_id,
+ "user_id": user_id,
+ "type": "membership",
+ "content": json.dumps({"membership": membership, "content": content}),
+ }
+ )
+ self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+
+ # TODO: Insert profile to ensure it comes down stream if its a join.
+
+ if membership == "join":
+ if local_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_renewals",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": local_attestation["valid_until_ms"],
+ }
+ )
+ if remote_attestation:
+ self._simple_insert_txn(
+ txn,
+ table="group_attestations_remote",
+ values={
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": remote_attestation["valid_until_ms"],
+ "attestation_json": json.dumps(remote_attestation),
+ }
+ )
+ else:
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+ self._simple_delete_txn(
+ txn,
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ )
+
+ return next_id
+
+ with self._group_updates_id_gen.get_next() as next_id:
+ res = yield self.runInteraction(
+ "register_user_group_membership",
+ _register_user_group_membership_txn, next_id,
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def create_group(self, group_id, user_id, name, avatar_url, short_description,
+ long_description,):
+ yield self._simple_insert(
+ table="groups",
+ values={
+ "group_id": group_id,
+ "name": name,
+ "avatar_url": avatar_url,
+ "short_description": short_description,
+ "long_description": long_description,
+ },
+ desc="create_group",
+ )
+
+ @defer.inlineCallbacks
+ def update_group_profile(self, group_id, profile,):
+ yield self._simple_update_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ updatevalues=profile,
+ desc="update_group_profile",
+ )
+
+ def get_attestations_need_renewals(self, valid_until_ms):
+ """Get all attestations that need to be renewed until givent time
+ """
+ def _get_attestations_need_renewals_txn(txn):
+ sql = """
+ SELECT group_id, user_id FROM group_attestations_renewals
+ WHERE valid_until_ms <= ?
+ """
+ txn.execute(sql, (valid_until_ms,))
+ return self.cursor_to_dict(txn)
+ return self.runInteraction(
+ "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+ )
+
+ def update_attestation_renewal(self, group_id, user_id, attestation):
+ """Update an attestation that we have renewed
+ """
+ return self._simple_update_one(
+ table="group_attestations_renewals",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "valid_until_ms": attestation["valid_until_ms"],
+ },
+ desc="update_attestation_renewal",
+ )
+
+ def update_remote_attestion(self, group_id, user_id, attestation):
+ """Update an attestation that a remote has renewed
+ """
+ return self._simple_update_one(
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ updatevalues={
+ "valid_until_ms": attestation["valid_until_ms"],
+ "attestation_json": json.dumps(attestation)
+ },
+ desc="update_remote_attestion",
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_attestation(self, group_id, user_id):
+ """Get the attestation that proves the remote agrees that the user is
+ in the group.
+ """
+ row = yield self._simple_select_one(
+ table="group_attestations_remote",
+ keyvalues={
+ "group_id": group_id,
+ "user_id": user_id,
+ },
+ retcols=("valid_until_ms", "attestation_json"),
+ desc="get_remote_attestation",
+ allow_none=True,
+ )
+
+ now = int(self._clock.time_msec())
+ if row and now < row["valid_until_ms"]:
+ defer.returnValue(json.loads(row["attestation_json"]))
+
+ defer.returnValue(None)
+
+ def get_joined_groups(self, user_id):
+ return self._simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={
+ "user_id": user_id,
+ "membership": "join",
+ },
+ retcol="group_id",
+ desc="get_joined_groups",
+ )
+
+ def get_all_groups_for_user(self, user_id, now_token):
+ def _get_all_groups_for_user_txn(txn):
+ sql = """
+ SELECT group_id, type, membership, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND membership != 'leave'
+ AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, now_token,))
+ return [
+ {
+ "group_id": row[0],
+ "type": row[1],
+ "membership": row[2],
+ "content": json.loads(row[3]),
+ }
+ for row in txn
+ ]
+ return self.runInteraction(
+ "get_all_groups_for_user", _get_all_groups_for_user_txn,
+ )
+
+ def get_groups_changes_for_user(self, user_id, from_token, to_token):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_entity_changed(
+ user_id, from_token,
+ )
+ if not has_changed:
+ return []
+
+ def _get_groups_changes_for_user_txn(txn):
+ sql = """
+ SELECT group_id, membership, type, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, from_token, to_token,))
+ return [{
+ "group_id": group_id,
+ "membership": membership,
+ "type": gtype,
+ "content": json.loads(content_json),
+ } for group_id, membership, gtype, content_json in txn]
+ return self.runInteraction(
+ "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+ )
+
+ def get_all_groups_changes(self, from_token, to_token, limit):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+ from_token,
+ )
+ if not has_changed:
+ return []
+
+ def _get_all_groups_changes_txn(txn):
+ sql = """
+ SELECT stream_id, group_id, user_id, type, content
+ FROM local_group_updates
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit,))
+ return [(
+ stream_id,
+ group_id,
+ user_id,
+ gtype,
+ json.loads(content_json),
+ ) for stream_id, group_id, user_id, gtype, content_json in txn]
+ return self.runInteraction(
+ "get_all_groups_changes", _get_all_groups_changes_txn,
+ )
+
+ def get_group_stream_token(self):
+ return self._group_updates_id_gen.get_current_token()
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 82bb61b811..7110a71279 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -62,7 +62,7 @@ class MediaRepositoryStore(SQLBaseStore):
def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts)
sql = (
- "SELECT response_code, etag, expires, og, media_id, download_ts"
+ "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1"
@@ -74,7 +74,7 @@ class MediaRepositoryStore(SQLBaseStore):
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
sql = (
- "SELECT response_code, etag, expires, og, media_id, download_ts"
+ "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1"
@@ -86,14 +86,14 @@ class MediaRepositoryStore(SQLBaseStore):
return None
return dict(zip((
- 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
+ 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
), row))
return self.runInteraction(
"get_url_cache", get_url_cache_txn
)
- def store_url_cache(self, url, response_code, etag, expires, og, media_id,
+ def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
download_ts):
return self._simple_insert(
"local_media_repository_url_cache",
@@ -101,7 +101,7 @@ class MediaRepositoryStore(SQLBaseStore):
"url": url,
"response_code": response_code,
"etag": etag,
- "expires": expires,
+ "expires_ts": expires_ts,
"og": og,
"media_id": media_id,
"download_ts": download_ts,
@@ -238,3 +238,64 @@ class MediaRepositoryStore(SQLBaseStore):
},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+
+ def get_expired_url_cache(self, now_ts):
+ sql = (
+ "SELECT media_id FROM local_media_repository_url_cache"
+ " WHERE expires_ts < ?"
+ " ORDER BY expires_ts ASC"
+ " LIMIT 500"
+ )
+
+ def _get_expired_url_cache_txn(txn):
+ txn.execute(sql, (now_ts,))
+ return [row[0] for row in txn]
+
+ return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+
+ def delete_url_cache(self, media_ids):
+ sql = (
+ "DELETE FROM local_media_repository_url_cache"
+ " WHERE media_id = ?"
+ )
+
+ def _delete_url_cache_txn(txn):
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+
+ def get_url_cache_media_before(self, before_ts):
+ sql = (
+ "SELECT media_id FROM local_media_repository"
+ " WHERE created_ts < ? AND url_cache IS NOT NULL"
+ " ORDER BY created_ts ASC"
+ " LIMIT 500"
+ )
+
+ def _get_url_cache_media_before_txn(txn):
+ txn.execute(sql, (before_ts,))
+ return [row[0] for row in txn]
+
+ return self.runInteraction(
+ "get_url_cache_media_before", _get_url_cache_media_before_txn,
+ )
+
+ def delete_url_cache_media(self, media_ids):
+ def _delete_url_cache_media_txn(txn):
+ sql = (
+ "DELETE FROM local_media_repository"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ sql = (
+ "DELETE FROM local_media_repository_thumbnails"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ return self.runInteraction(
+ "delete_url_cache_media", _delete_url_cache_media_txn,
+ )
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 72b670b83b..ccaaabcfa0 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 43
+SCHEMA_VERSION = 45
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..beea3102fc 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
from ._base import SQLBaseStore
@@ -55,3 +57,99 @@ class ProfileStore(SQLBaseStore):
updatevalues={"avatar_url": new_avatar_url},
desc="set_profile_avatar_url",
)
+
+ def get_from_remote_profile_cache(self, user_id):
+ return self._simple_select_one(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ retcols=("displayname", "avatar_url",),
+ allow_none=True,
+ desc="get_from_remote_profile_cache",
+ )
+
+ def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+ """Ensure we are caching the remote user's profiles.
+
+ This should only be called when `is_subscribed_remote_profile_for_user`
+ would return true for the user.
+ """
+ return self._simple_upsert(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="add_remote_profile_cache",
+ )
+
+ def update_remote_profile_cache(self, user_id, displayname, avatar_url):
+ return self._simple_update(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ values={
+ "displayname": displayname,
+ "avatar_url": avatar_url,
+ "last_check": self._clock.time_msec(),
+ },
+ desc="update_remote_profile_cache",
+ )
+
+ @defer.inlineCallbacks
+ def maybe_delete_remote_profile_cache(self, user_id):
+ """Check if we still care about the remote user's profile, and if we
+ don't then remove their profile from the cache
+ """
+ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
+ if not subscribed:
+ yield self._simple_delete(
+ table="remote_profile_cache",
+ keyvalues={"user_id": user_id},
+ desc="delete_remote_profile_cache",
+ )
+
+ def get_remote_profile_cache_entries_that_expire(self, last_checked):
+ """Get all users who haven't been checked since `last_checked`
+ """
+ def _get_remote_profile_cache_entries_that_expire_txn(txn):
+ sql = """
+ SELECT user_id, displayname, avatar_url
+ FROM remote_profile_cache
+ WHERE last_check < ?
+ """
+
+ txn.execute(sql, (last_checked,))
+
+ return self.cursor_to_dict(txn)
+
+ return self.runInteraction(
+ "get_remote_profile_cache_entries_that_expire",
+ _get_remote_profile_cache_entries_that_expire_txn,
+ )
+
+ @defer.inlineCallbacks
+ def is_subscribed_remote_profile_for_user(self, user_id):
+ """Check whether we are interested in a remote user's profile.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ defer.returnValue(True)
+
+ res = yield self._simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={"user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="should_update_remote_profile_cache_for_user",
+ )
+
+ if res:
+ defer.returnValue(True)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 457ca288d0..a0fc9a6867 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -533,6 +533,46 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(True)
+ @cachedInlineCallbacks()
+ def was_host_joined(self, room_id, host):
+ """Check whether the server is or ever was in the room.
+
+ Args:
+ room_id (str)
+ host (str)
+
+ Returns:
+ Deferred: Resolves to True if the host is/was in the room, otherwise
+ False.
+ """
+ if '%' in host or '_' in host:
+ raise Exception("Invalid host name")
+
+ sql = """
+ SELECT user_id FROM room_memberships
+ WHERE room_id = ?
+ AND user_id LIKE ?
+ AND membership = 'join'
+ LIMIT 1
+ """
+
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
+
+ rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
+
+ if not rows:
+ defer.returnValue(False)
+
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
+
+ defer.returnValue(True)
+
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql
new file mode 100644
index 0000000000..e2b775f038
--- /dev/null
+++ b/synapse/storage/schema/delta/44/expire_url_cache.sql
@@ -0,0 +1,38 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
+
+-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
+-- indices on expressions until 3.9.
+CREATE TABLE local_media_repository_url_cache_new(
+ url TEXT,
+ response_code INTEGER,
+ etag TEXT,
+ expires_ts BIGINT,
+ og TEXT,
+ media_id TEXT,
+ download_ts BIGINT
+);
+
+INSERT INTO local_media_repository_url_cache_new
+ SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
+
+DROP TABLE local_media_repository_url_cache;
+ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
+
+CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
+CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
+CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql
new file mode 100644
index 0000000000..b2333848a0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/group_server.sql
@@ -0,0 +1,167 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups (
+ group_id TEXT NOT NULL,
+ name TEXT, -- the display name of the room
+ avatar_url TEXT,
+ short_description TEXT,
+ long_description TEXT
+);
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
+
+
+-- list of users the group server thinks are joined
+CREATE TABLE group_users (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ is_admin BOOLEAN NOT NULL,
+ is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone
+);
+
+
+CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
+CREATE INDEX groups_users_u_idx ON group_users(user_id);
+
+-- list of users the group server thinks are invited
+CREATE TABLE group_invites (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL
+);
+
+CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
+CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
+
+
+CREATE TABLE group_rooms (
+ group_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
+);
+
+CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
+CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
+
+
+-- Rooms to include in the summary
+CREATE TABLE group_summary_rooms (
+ group_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ room_order BIGINT NOT NULL,
+ is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone
+ UNIQUE (group_id, category_id, room_id, room_order),
+ CHECK (room_order > 0)
+);
+
+CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
+
+
+-- Categories to include in the summary
+CREATE TABLE group_summary_room_categories (
+ group_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ cat_order BIGINT NOT NULL,
+ UNIQUE (group_id, category_id, cat_order),
+ CHECK (cat_order > 0)
+);
+
+-- The categories in the group
+CREATE TABLE group_room_categories (
+ group_id TEXT NOT NULL,
+ category_id TEXT NOT NULL,
+ profile TEXT NOT NULL,
+ is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone
+ UNIQUE (group_id, category_id)
+);
+
+-- The users to include in the group summary
+CREATE TABLE group_summary_users (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ user_order BIGINT NOT NULL,
+ is_public BOOLEAN NOT NULL -- whether the user should be show to everyone
+);
+
+CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
+
+-- The roles to include in the group summary
+CREATE TABLE group_summary_roles (
+ group_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ role_order BIGINT NOT NULL,
+ UNIQUE (group_id, role_id, role_order),
+ CHECK (role_order > 0)
+);
+
+
+-- The roles in a groups
+CREATE TABLE group_roles (
+ group_id TEXT NOT NULL,
+ role_id TEXT NOT NULL,
+ profile TEXT NOT NULL,
+ is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone
+ UNIQUE (group_id, role_id)
+);
+
+
+-- List of attestations we've given out and need to renew
+CREATE TABLE group_attestations_renewals (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ valid_until_ms BIGINT NOT NULL
+);
+
+CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
+CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
+CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
+
+
+-- List of attestations we've received from remotes and are interested in.
+CREATE TABLE group_attestations_remote (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ valid_until_ms BIGINT NOT NULL,
+ attestation_json TEXT NOT NULL
+);
+
+CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
+CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
+CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
+
+
+-- The group membership for the HS's users
+CREATE TABLE local_group_membership (
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ is_admin BOOLEAN NOT NULL,
+ membership TEXT NOT NULL,
+ is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership
+ content TEXT NOT NULL
+);
+
+CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
+CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
+
+
+CREATE TABLE local_group_updates (
+ stream_id BIGINT NOT NULL,
+ group_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ content TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql
new file mode 100644
index 0000000000..e5ddc84df0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/profile_cache.sql
@@ -0,0 +1,28 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A subset of remote users whose profiles we have cached.
+-- Whether a user is in this table or not is defined by the storage function
+-- `is_subscribed_remote_profile_for_user`
+CREATE TABLE remote_profile_cache (
+ user_id TEXT NOT NULL,
+ displayname TEXT,
+ avatar_url TEXT,
+ last_check BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
+CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 91a59b0bae..f03ad99118 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -45,6 +45,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
+ groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=(
@@ -65,6 +66,7 @@ class EventSources(object):
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
+ groups_key=groups_key,
)
defer.returnValue(token)
@@ -73,6 +75,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
+ groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=(
@@ -93,5 +96,6 @@ class EventSources(object):
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
+ groups_key=groups_key,
)
defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index 111948540d..37d5fa7f9f 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -156,6 +156,11 @@ class EventID(DomainSpecificString):
SIGIL = "$"
+class GroupID(DomainSpecificString):
+ """Structure representing a group ID."""
+ SIGIL = "+"
+
+
class StreamToken(
namedtuple("Token", (
"room_key",
@@ -166,6 +171,7 @@ class StreamToken(
"push_rules_key",
"to_device_key",
"device_list_key",
+ "groups_key",
))
):
_SEPARATOR = "_"
@@ -204,6 +210,7 @@ class StreamToken(
or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
+ or (int(other.groups_key) < int(self.groups_key))
)
def copy_and_advance(self, key, new_value):
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1453faf0ef..a0a9039475 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
from .logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
-from synapse.util import unwrapFirstError
+from synapse.util import logcontext, unwrapFirstError
from contextlib import contextmanager
@@ -53,6 +53,11 @@ class ObservableDeferred(object):
Cancelling or otherwise resolving an observer will not affect the original
ObservableDeferred.
+
+ NB that it does not attempt to do anything with logcontexts; in general
+ you should probably make_deferred_yieldable the deferreds
+ returned by `observe`, and ensure that the original deferred runs its
+ callbacks in the sentinel logcontext.
"""
__slots__ = ["_deferred", "_observers", "_result"]
@@ -155,7 +160,7 @@ def concurrently_execute(func, args, limit):
except StopIteration:
pass
- return preserve_context_over_deferred(defer.gatherResults([
+ return logcontext.make_deferred_yieldable(defer.gatherResults([
preserve_fn(_concurrently_execute_inner)()
for _ in xrange(limit)
], consumeErrors=True)).addErrback(unwrapFirstError)
@@ -203,7 +208,26 @@ class Linearizer(object):
except:
logger.exception("Unexpected exception in Linearizer")
- logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+ logger.info("Acquired linearizer lock %r for key %r", self.name,
+ key)
+
+ # if the code holding the lock completes synchronously, then it
+ # will recursively run the next claimant on the list. That can
+ # relatively rapidly lead to stack exhaustion. This is essentially
+ # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+ #
+ # In order to break the cycle, we add a cheeky sleep(0) here to
+ # ensure that we fall back to the reactor between each iteration.
+ #
+ # (There's no particular need for it to happen before we return
+ # the context manager, but it needs to happen while we hold the
+ # lock, and the context manager's exit code must be synchronous,
+ # so actually this is the only sensible place.
+ yield run_on_reactor()
+
+ else:
+ logger.info("Acquired uncontended linearizer lock %r for key %r",
+ self.name, key)
@contextmanager
def _ctx_manager():
@@ -211,7 +235,8 @@ class Linearizer(object):
yield
finally:
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
- new_defer.callback(None)
+ with PreserveLoggingContext():
+ new_defer.callback(None)
current_d = self.key_to_defer.get(key)
if current_d is new_defer:
self.key_to_defer.pop(key, None)
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
new file mode 100644
index 0000000000..cdbc4bffd7
--- /dev/null
+++ b/synapse/util/logformatter.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import StringIO
+import logging
+import traceback
+
+
+class LogFormatter(logging.Formatter):
+ """Log formatter which gives more detail for exceptions
+
+ This is the same as the standard log formatter, except that when logging
+ exceptions [typically via log.foo("msg", exc_info=1)], it prints the
+ sequence that led up to the point at which the exception was caught.
+ (Normally only stack frames between the point the exception was raised and
+ where it was caught are logged).
+ """
+ def __init__(self, *args, **kwargs):
+ super(LogFormatter, self).__init__(*args, **kwargs)
+
+ def formatException(self, ei):
+ sio = StringIO.StringIO()
+ (typ, val, tb) = ei
+
+ # log the stack above the exception capture point if possible, but
+ # check that we actually have an f_back attribute to work around
+ # https://twistedmatrix.com/trac/ticket/9305
+
+ if tb and hasattr(tb.tb_frame, 'f_back'):
+ sio.write("Capture point (most recent call last):\n")
+ traceback.print_stack(tb.tb_frame.f_back, None, sio)
+
+ traceback.print_exception(typ, val, tb, None, sio)
+ s = sio.getvalue()
+ sio.close()
+ if s[-1:] == "\n":
+ s = s[:-1]
+ return s
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
new file mode 100644
index 0000000000..4288312b8a
--- /dev/null
+++ b/synapse/util/module_loader.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+
+from synapse.config._base import ConfigError
+
+
+def load_module(provider):
+ """ Loads a module with its config
+ Take a dict with keys 'module' (the module name) and 'config'
+ (the config dict).
+
+ Returns
+ Tuple of (provider class, parsed config object)
+ """
+ # We need to import the module, and then pick the class out of
+ # that, so we split based on the last dot.
+ module, clz = provider['module'].rsplit(".", 1)
+ module = importlib.import_module(module)
+ provider_class = getattr(module, clz)
+
+ try:
+ provider_config = provider_class.parse_config(provider["config"])
+ except Exception as e:
+ raise ConfigError(
+ "Failed to parse config for %r: %r" % (provider['module'], e)
+ )
+
+ return provider_class, provider_config
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 2a203129ca..a5f47181d7 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
- hs.handlers = ProfileHandlers(hs)
-
self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test")
@@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
yield self.store.create_profile(self.frank.localpart)
- self.handler = hs.get_handlers().profile_handler
+ self.handler = hs.get_profile_handler()
@defer.inlineCallbacks
def test_get_my_name(self):
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index c8cf9a63ec..e990e45220 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
self.hs = yield setup_test_homeserver(
handlers=None,
http_client=None,
- expire_access_token=True)
+ expire_access_token=True,
+ profile_handler=Mock(),
+ )
self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret'))
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
- self.hs.get_handlers().profile_handler = Mock()
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 1e95e97538..dddcf51b69 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase):
resource_for_client=self.mock_resource,
federation=Mock(),
replication_layer=Mock(),
+ profile_handler=self.mock_handler
)
def _get_user_by_req(request=None, allow_guest=False):
@@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase):
hs.get_v1auth().get_user_by_req = _get_user_by_req
- hs.get_handlers().profile_handler = self.mock_handler
-
profile.register_servlets(hs, self.mock_resource)
@defer.inlineCallbacks
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d746ea8568..de376fb514 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_topo_token_is_accepted(self):
- token = "t1-0_0_0_0_0_0_0_0"
+ token = "t1-0_0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self):
- token = "s0_0_0_0_0_0_0_0"
+ token = "s0_0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index b6173ab2ee..821c735528 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True
+ self.hs.config.auto_join_rooms = []
# init the thing we're testing
self.servlet = RegisterRestServlet(self.hs)
diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py
deleted file mode 100644
index 024ac15069..0000000000
--- a/tests/storage/event_injector.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-
-
-class EventInjector:
- def __init__(self, hs):
- self.hs = hs
- self.store = hs.get_datastore()
- self.message_handler = hs.get_handlers().message_handler
- self.event_builder_factory = hs.get_event_builder_factory()
-
- @defer.inlineCallbacks
- def create_room(self, room, user):
- builder = self.event_builder_factory.new({
- "type": EventTypes.Create,
- "sender": user.to_string(),
- "room_id": room.to_string(),
- "content": {},
- })
-
- event, context = yield self.message_handler._create_new_client_event(
- builder
- )
-
- yield self.store.persist_event(event, context)
-
- @defer.inlineCallbacks
- def inject_room_member(self, room, user, membership):
- builder = self.event_builder_factory.new({
- "type": EventTypes.Member,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
- "content": {"membership": membership},
- })
-
- event, context = yield self.message_handler._create_new_client_event(
- builder
- )
-
- yield self.store.persist_event(event, context)
-
- defer.returnValue(event)
-
- @defer.inlineCallbacks
- def inject_message(self, room, user, body):
- builder = self.event_builder_factory.new({
- "type": EventTypes.Message,
- "sender": user.to_string(),
- "state_key": user.to_string(),
- "room_id": room.to_string(),
- "content": {"body": body, "msgtype": u"message"},
- })
-
- event, context = yield self.message_handler._create_new_client_event(
- builder
- )
-
- yield self.store.persist_event(event, context)
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index afcba482f9..793a88e462 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
+from synapse.util import async, logcontext
from tests import unittest
from twisted.internet import defer
@@ -38,7 +37,28 @@ class LinearizerTestCase(unittest.TestCase):
with cm1:
self.assertFalse(d2.called)
- self.assertTrue(d2.called)
-
with (yield d2):
pass
+
+ def test_lots_of_queued_things(self):
+ # we have one slow thing, and lots of fast things queued up behind it.
+ # it should *not* explode the stack.
+ linearizer = Linearizer()
+
+ @defer.inlineCallbacks
+ def func(i, sleep=False):
+ with logcontext.LoggingContext("func(%s)" % i) as lc:
+ with (yield linearizer.queue("")):
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(), lc)
+ if sleep:
+ yield async.sleep(0)
+
+ self.assertEqual(
+ logcontext.LoggingContext.current_context(), lc)
+
+ func(0, sleep=True)
+ for i in xrange(1, 100):
+ func(i)
+
+ return func(1000)
diff --git a/tests/util/test_log_context.py b/tests/util/test_logcontext.py
index 9ffe209c4d..e2f7765f49 100644
--- a/tests/util/test_log_context.py
+++ b/tests/util/test_logcontext.py
@@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase):
yield defer.succeed(None)
return self._test_preserve_fn(nonblocking_function)
+
+ @defer.inlineCallbacks
+ def test_make_deferred_yieldable(self):
+ # a function which retuns an incomplete deferred, but doesn't follow
+ # the synapse rules.
+ def blocking_function():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, None)
+ return d
+
+ sentinel_context = LoggingContext.current_context()
+
+ with LoggingContext() as context_one:
+ context_one.test_key = "one"
+
+ d1 = logcontext.make_deferred_yieldable(blocking_function())
+ # make sure that the context was reset by make_deferred_yieldable
+ self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+ yield d1
+
+ # now it should be restored
+ self._check_test_key("one")
+
+ @defer.inlineCallbacks
+ def test_make_deferred_yieldable_on_non_deferred(self):
+ """Check that make_deferred_yieldable does the right thing when its
+ argument isn't actually a deferred"""
+
+ with LoggingContext() as context_one:
+ context_one.test_key = "one"
+
+ d1 = logcontext.make_deferred_yieldable("bum")
+ self._check_test_key("one")
+
+ r = yield d1
+ self.assertEqual(r, "bum")
+ self._check_test_key("one")
|