From ae6c4c1b51cafb4d5534b5849710560bcafd32e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20R=2E=20Sede=C3=B1o?= Date: Sun, 13 Sep 2020 13:58:11 -0400 Subject: [PATCH] functioning prototype This is good enough to interact with it as a dev instance with a slightly modified snipe. WARNING: This commit reinitializes migrations. Not that anyone besides me was using this before since it didn't really work. Old databases should be discarded. --- roost_backend/apps.py | 6 + roost_backend/consumers.py | 130 +++++++-- roost_backend/migrations/0001_initial.py | 24 +- roost_backend/models.py | 107 ++++++-- roost_backend/secrets.py | 3 + roost_backend/serializers.py | 29 +- roost_backend/signals.py | 82 ++++++ roost_backend/user_process.py | 328 +++++++++++++++++------ roost_backend/utils/__init__.py | 26 +- roost_backend/utils/kerberos.py | 56 +++- roost_backend/views.py | 29 +- roost_ng/settings/gssapi.py | 3 + roost_ng/settings/logging.py | 5 + 13 files changed, 660 insertions(+), 168 deletions(-) create mode 100644 roost_backend/signals.py diff --git a/roost_backend/apps.py b/roost_backend/apps.py index 46360fe..0a6c541 100644 --- a/roost_backend/apps.py +++ b/roost_backend/apps.py @@ -3,3 +3,9 @@ class RoostBackendConfig(AppConfig): name = 'roost_backend' + + def ready(self): + # pylint: disable=import-outside-toplevel, unused-import + # This is for side-effects of hooking up signals. + from . import signals # noqa: F401 + super().ready() diff --git a/roost_backend/consumers.py b/roost_backend/consumers.py index 67c8799..1cbddb2 100644 --- a/roost_backend/consumers.py +++ b/roost_backend/consumers.py @@ -1,15 +1,16 @@ import logging +from asgiref.sync import async_to_sync from channels.generic.websocket import JsonWebsocketConsumer +from djangorestframework_camel_case.util import camelize from .authentication import JWTAuthentication +from . 
import filters, serializers, utils _LOGGER = logging.getLogger(__name__) class UserSocketConsumer(JsonWebsocketConsumer): - groups = ['broadcast'] - class BadMessage(Exception): pass @@ -17,9 +18,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.user = None self.tails = {} - - def connect(self): - self.accept() + self.active_tails = set() def receive_json(self, content, **kwargs): msg_type = content.get('type') @@ -33,6 +32,9 @@ def receive_json(self, content, **kwargs): self.close(code=4002) return self.user = user + async_to_sync(self.channel_layer.group_add)( + utils.principal_to_user_socket_group_name(user.principal), + self.channel_name) self.send_json({'type': 'ready'}) return @@ -56,21 +58,18 @@ def on_new_tail(self, content): inclusive = content.get('inclusive', False) if not all((isinstance(tail_id, int), - start is None or not isinstance(start, str), + start is None or isinstance(start, str), isinstance(inclusive, bool))): raise self.BadMessage() if start is None: start = 0 else: - # TODO: unseal message_id `start` + start = utils.unseal_message_id(start) if inclusive: - # TODO: if `inclusive`, decrement `start` - pass + start -= 1 - # TODO: construct filter from `content` - t_filter = None - # t_filter = Filter(content) + t_filter = filters.MessageFilter(**content) if tail_id in self.tails: # Roost frowned upon reusing tail ids in comments, and then closed the existing tail @@ -78,18 +77,16 @@ def on_new_tail(self, content): _LOGGER.debug('User "%s" has reused tail id "%i".', self.user, tail_id) self.tails[tail_id].close() - self.tails['tail_id'] = Tail(self, tail_id, start, t_filter) - - raise NotImplementedError() + self.tails[tail_id] = Tail(self, tail_id, start, t_filter) def on_extend_tail(self, content): tail_id = content.get('id') count = content.get('count') if not all((isinstance(tail_id, int), - isinstance(count, int))): + isinstance(count, int), + tail_id in self.tails)): raise self.BadMessage() - - raise NotImplementedError() + self.tails[tail_id].extend(count) def on_close_tail(self, content): tail_id = content.get('id') @@ -99,10 +96,27 @@ def on_close_tail(self, content): if tail_id in self.tails: self.tails.pop(tail_id).close() - def disconenct(self, close_code): - _LOGGER.debug('WebSocket for user "%s" closed by client with code "%s".', self.user, close_code) + def disconnect(self, code): + _LOGGER.debug('WebSocket for user "%s" closed by client with code "%s".', self.user, code) + + if self.user is not None: + async_to_sync(self.channel_layer.group_discard)( + utils.principal_to_user_socket_group_name(self.user.principal), + self.channel_name) + + for tail in self.tails.values(): + tail.close() + self.tails = {} + self.close() + # Start of Channel Layer message handlers + def incoming_message(self, message): + # don't iterate over active_tails itself as its size may change while we do that. + for tail in list(self.active_tails): + tail.on_message(message['message']) + # End message handlers + class Tail: def __init__(self, socket, t_id, start, t_filter): @@ -114,20 +128,84 @@ def __init__(self, socket, t_id, start, t_filter): self.active = False self.messages_sent = 0 self.messages_wanted = 0 + self.message_buffer = None def close(self): + self.deactivate() self.socket = None - # TODO: stop doing things, once we figure out what things are. 
- raise NotImplementedError() def extend(self, count): - pass + _LOGGER.debug('tail: extending %i', count) + + self.messages_wanted = max(count - self.messages_sent, + self.messages_wanted) + self.do_query() def activate(self): - pass + if not self.active: + self.active = True + self.socket.active_tails.add(self) def deactivate(self): - pass + if self.active: + self.active = False + self.socket.active_tails.remove(self) def do_query(self): - pass + if self.socket is None: + return + + if self.active: + return + + if self.messages_wanted == 0: + return + + self.activate() + self.message_buffer = [] + qs = self.user.message_set.filter(id__gt=self.last_sent) + qs = self.t_filter.apply_to_queryset(qs)[:self.messages_wanted] + messages = [{'id': msg.id, + 'payload': serializers.MessageSerializer(msg).data, + } for msg in list(qs)] + _LOGGER.debug('tail query returned %i messages', len(messages)) + self.emit_messages(messages) + + if self.messages_wanted: + message_buffer, self.message_buffer = self.message_buffer, None + message_buffer = [msg for msg in message_buffer if msg.id > self.last_sent] + self.emit_messages(message_buffer) + + if not self.messages_wanted: + self.deactivate() + + def on_message(self, message): + if not self.socket: + return + if not self.t_filter.matches_message(message): + return + if self.last_sent >= message['id']: + return + if isinstance(self.message_buffer, list): + self.message_buffer.append(message) + return + + self.emit_messages([message]) + if self.messages_wanted == 0: + self.deactivate() + + def emit_messages(self, messages): + if messages: + self.socket.send_json({ + 'type': 'messages', + 'id': self.t_id, + 'messages': [camelize(msg['payload']) for msg in messages], + 'isDone': True, + }) + count = len(messages) + self.messages_sent += count + if count >= self.messages_wanted: + self.messages_wanted = 0 + else: + self.messages_wanted -= count + self.last_sent = messages[-1]['id'] diff --git a/roost_backend/migrations/0001_initial.py b/roost_backend/migrations/0001_initial.py index 6cdc555..0299f2e 100644 --- a/roost_backend/migrations/0001_initial.py +++ b/roost_backend/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 3.0.8 on 2020-07-18 15:08 +# Generated by Django 3.1.1 on 2020-09-13 17:49 from django.db import migrations, models import django.db.models.deletion @@ -12,15 +12,29 @@ class Migration(migrations.Migration): ] operations = [ + migrations.CreateModel( + name='ServerProcessState', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('data', models.JSONField()), + ], + ), migrations.CreateModel( name='User', fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('principal', models.CharField(max_length=255, unique=True)), - ('info', models.TextField(default='{}')), + ('info', models.JSONField(default=dict)), ('info_version', models.BigIntegerField(default=1)), ], ), + migrations.CreateModel( + name='UserProcessState', + fields=[ + ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='process_state', serialize=False, to='roost_backend.user')), + ('data', models.JSONField()), + ], + ), migrations.CreateModel( name='Subscription', fields=[ @@ -30,7 +44,7 @@ class Migration(migrations.Migration): ('zrecipient', models.CharField(max_length=255)), ('class_key', models.CharField(max_length=255)), ('instance_key', models.CharField(max_length=255)), - ('user', 
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='roost_backend.User')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='roost_backend.user')), ], options={ 'unique_together': {('user', 'zrecipient', 'class_key', 'instance_key')}, @@ -57,12 +71,12 @@ class Migration(migrations.Migration): ('uid', models.CharField(max_length=16)), ('opcode', models.CharField(blank=True, max_length=255)), ('signature', models.CharField(max_length=255)), - ('message', models.BinaryField()), + ('message', models.TextField()), ('users', models.ManyToManyField(to='roost_backend.User')), ], options={ 'ordering': ['id'], - 'index_together': {('class_key_base', 'instance_key_base'), ('class_key', 'instance_key')}, + 'index_together': {('class_key', 'instance_key'), ('class_key_base', 'instance_key_base')}, }, ), ] diff --git a/roost_backend/models.py b/roost_backend/models.py index 660c409..fd01d72 100644 --- a/roost_backend/models.py +++ b/roost_backend/models.py @@ -1,7 +1,9 @@ +import base64 import datetime +import re +import socket +import struct -from asgiref.sync import async_to_sync -from channels.layers import get_channel_layer from django.conf import settings from django.db import models, transaction import jwt @@ -11,7 +13,7 @@ class User(models.Model): principal = models.CharField(max_length=255, unique=True) - info = models.TextField(default='{}') + info = models.JSONField(default=dict) info_version = models.BigIntegerField(default=1) # TODO: add minimum token age or token generation to invalidate old tokens. @@ -80,24 +82,13 @@ def is_authenticated(self): def is_anonymous(self): return self.id is None - @staticmethod - def _send_to_group(msg, group_name, wait_for_response=False): - channel_layer = get_channel_layer() - if wait_for_response: - channel_name = async_to_sync(channel_layer.new_channel)() - msg = dict(msg, _reply_to=channel_name) - async_to_sync(channel_layer.group_send)(group_name, msg) - if wait_for_response: - return async_to_sync(channel_layer.receive)(channel_name) - return None - def send_to_user_process(self, msg, wait_for_response=False): group_name = utils.principal_to_user_process_group_name(self.principal) - return self._send_to_group(msg, group_name, wait_for_response) + return utils.send_to_group(group_name, msg, wait_for_response) def send_to_user_sockets(self, msg, wait_for_response=False): group_name = utils.principal_to_user_socket_group_name(self.principal) - return self._send_to_group(msg, group_name, wait_for_response) + return utils.send_to_group(group_name, msg, wait_for_response) class Meta: pass @@ -120,6 +111,9 @@ class Meta: ] +RE_BASE_STR = re.compile(r'(?:un)*(.*?)(?:[.]d)*') + + class Message(models.Model): users = models.ManyToManyField('User') # display data @@ -150,7 +144,58 @@ class Message(models.Model): opcode = models.CharField(max_length=255, blank=True) signature = models.CharField(max_length=255) - message = models.BinaryField() + message = models.TextField() + + def __str__(self): + return f'[{self.uid}] {self.class_key},{self.instance_key},{self.recipient if self.recipient else "*"}' + + @classmethod + def from_notice(cls, notice, is_outgoing=False): + # Further needed arguments: direction, user? 
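        # (Reviewer note, not part of the patch: the `_d` helper defined just
        # below decodes a notice field using the notice's declared charset when
        # it is one we recognize, and otherwise walks an ascii/utf-8/latin-1
        # fallback chain; latin-1 accepts any byte sequence, so the chain
        # always yields a str. A standalone sketch of the same idea, stdlib
        # only, under a hypothetical name:
        #
        #     def decode_best_effort(octets: bytes, charset: bytes = b'') -> str:
        #         known = {b'UTF-8': 'utf-8', b'ISO-8859-1': 'latin-1'}
        #         if charset in known:
        #             return octets.decode(known[charset])
        #         for enc in ('ascii', 'utf-8', 'latin-1'):
        #             try:
        #                 return octets.decode(enc)
        #             except UnicodeDecodeError:
        #                 continue
        # )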
+ + def _d(octets: bytes) -> str: + # pylint: disable=protected-access + if notice._charset == b'UTF-8': + return octets.decode('utf-8') + if notice._charset == b'ISO-8859-1': + return octets.decode('latin-1') + for enc in ('ascii', 'utf-8', 'latin-1'): + try: + return octets.decode(enc) + except UnicodeDecodeError: + pass + + msg = cls() + msg.zclass = _d(notice.cls) + msg.zinstance = _d(notice.instance) + msg.class_key = msg.zclass.casefold() + msg.instance_key = msg.zinstance.casefold() + msg.class_key_base = RE_BASE_STR.fullmatch(msg.class_key).group(1) + msg.instance_key_base = RE_BASE_STR.fullmatch(msg.instance_key).group(1) + msg.time = datetime.datetime.fromtimestamp(notice.time or notice.uid.time, datetime.timezone.utc) + msg.auth = notice.auth + msg.sender = _d(notice.sender) + msg.recipient = _d(notice.recipient) + + msg.is_personal = bool(msg.recipient) + msg.is_outgoing = is_outgoing + if msg.is_personal: + msg.conversation = msg.recipient if is_outgoing else msg.sender + + # Reconstruct the Zuid from its component parts and store it like roost would. + uid = socket.inet_aton(notice.uid.address.decode('ascii')) + uid_time = datetime.datetime.fromtimestamp(notice.uid.time, datetime.timezone.utc) + uid += struct.pack('!II', int(uid_time.timestamp()), int(uid_time.microsecond)) + msg.uid = base64.b64encode(uid).decode('ascii') + + msg.opcode = _d(notice.opcode) + if len(notice.fields) == 2: + msg.signature = _d(notice.fields[0])[:255] + msg.message = _d(notice.fields[1]) + elif len(notice.fields) == 1: + msg.message = _d(notice.fields[0]) + + return msg class Meta: index_together = [ @@ -158,3 +203,31 @@ class Meta: ['class_key_base', 'instance_key_base'], ] ordering = ['id'] + + +class UserProcessState(models.Model): + """This class will be used to persist data the user process needs. The + `data` field format is defined by user_process.py. This table is + new to roost-ng and internal only, not to be exposed to clients. + """ + + user = models.OneToOneField('User', primary_key=True, on_delete=models.CASCADE, related_name='process_state') + data = models.JSONField() + + +class ServerProcessState(models.Model): + """This class will be used to persist data the server process needs. The + `data` field format is defined by user_process.py. This table is + new to roost-ng and internal only, not to be exposed to clients. + """ + + data = models.JSONField() + + def save(self, *args, **kwargs): + # pylint: disable=signature-differs + self.__class__.objects.exclude(id=self.id).delete() + super().save(*args, **kwargs) + + @classmethod + def load(cls): + return cls.objects.last() or cls() diff --git a/roost_backend/secrets.py b/roost_backend/secrets.py index efbee04..0258538 100644 --- a/roost_backend/secrets.py +++ b/roost_backend/secrets.py @@ -3,7 +3,10 @@ from django.conf import settings + _BASE_KEY = hashlib.blake2s(settings.SECRET_KEY.encode('utf-8'), digest_size=16).digest() + + def _secret_generator(context, as_uuid=False): """Derive secrets from SECRET_KEY and a context string""" ret = uuid.uuid5(uuid.UUID(bytes=_BASE_KEY), context) diff --git a/roost_backend/serializers.py b/roost_backend/serializers.py index 2741b35..506e1c2 100644 --- a/roost_backend/serializers.py +++ b/roost_backend/serializers.py @@ -4,7 +4,7 @@ from rest_framework import serializers -from . import models +from . 
import models, utils # pylint: disable=abstract-method # pylint complains about missing optional `create` and `update` methods in @@ -31,6 +31,14 @@ def to_internal_value(self, data): return datetime.datetime.fromtimestamp(data/1000, datetime.timezone.utc) +class SealedMessageIdField(serializers.UUIDField): + def to_representation(self, value): + return super().to_representation(utils.seal_message_id(value)) + + def to_internal_value(self, data): + return utils.unseal_message_id(super().to_internal_value(data)) + + # Kerberos Credential Serializers class _InlineNameSerializer(serializers.Serializer): name_type = serializers.IntegerField() @@ -154,14 +162,15 @@ class Meta: class MessageSerializer(serializers.ModelSerializer): + id = SealedMessageIdField() time = DateTimeAsMillisecondsField() receive_time = DateTimeAsMillisecondsField() - message = serializers.SerializerMethodField() + # message = serializers.SerializerMethodField() instance = serializers.CharField(source='zinstance') - @staticmethod - def get_message(obj): - return obj.message.decode('utf-8') + # @staticmethod + # def get_message(obj): + # return obj.message.decode('utf-8') class Meta: model = models.Message @@ -175,11 +184,11 @@ class Meta: class OutgoingMessageSerializer(serializers.Serializer): - instance = serializers.CharField() - recipient = serializers.CharField() - opcode = serializers.CharField(default='') - signature = serializers.CharField(default='') - message = serializers.CharField() + instance = serializers.CharField(allow_blank=True) + recipient = serializers.CharField(allow_blank=True) + opcode = serializers.CharField(default='', allow_blank=True) + signature = serializers.CharField(default='', allow_blank=True) + message = serializers.CharField(allow_blank=True) # class is a reserved word, so let's do this the hard way. diff --git a/roost_backend/signals.py b/roost_backend/signals.py new file mode 100644 index 0000000..cf28994 --- /dev/null +++ b/roost_backend/signals.py @@ -0,0 +1,82 @@ +from django.db.models.signals import post_delete, post_save +from django.dispatch import receiver + +from . 
import models, serializers, utils + + +@receiver(post_save, sender=models.Message) +def message_post_processing(sender, instance, created, **_kwargs): + # pylint: disable=unused-argument + if not created: + return + users = [] + if instance.is_personal: + if instance.is_outgoing: + users.append(models.User.objects.get(principal=instance.sender)) + else: + users.append(models.User.objects.get(principal=instance.recipient)) + elif not instance.is_outgoing: + users.extend(sub.user for sub in + models.Subscription.objects.filter( + class_key=instance.class_key, + instance_key__in=(instance.instance_key, '*'), + zrecipient=instance.recipient)) + + if users: + instance.users.add(*users) + payload = serializers.MessageSerializer(instance).data + for user in users: + user.send_to_user_sockets({ + 'type': 'incoming_message', + 'message': { + 'id': instance.id, + 'payload': payload, + } + }) + + +@receiver(post_save, sender=models.Subscription) +def resync_subscriber_on_subscription_save(sender, instance, created, **_kwargs): + # pylint: disable=unused-argument + if not created: + return + user = instance.user + payload = {'type': 'resync_subscriptions'} + if instance.zrecipient == user.principal: + # personal; send to user process + user.send_to_user_process(payload) + else: + utils.send_to_group('ROOST_SERVER_PROCESS', payload) + + +@receiver(post_delete, sender=models.Subscription) +def resync_subscriber_on_subscription_delete(sender, instance, **_kwargs): + # pylint: disable=unused-argument + user = instance.user + payload = {'type': 'resync_subscriptions'} + if not user: + return + if instance.zrecipient == user.principal: + # personal; send to user process + user.send_to_user_process(payload) + else: + utils.send_to_group('ROOST_SERVER_PROCESS', payload) + + +@receiver(post_save, sender=models.User) +def start_new_user_process(sender, instance, created, **_kwargs): + # pylint: disable=unused-argument + if created: + utils.send_to_group('UP_OVERSEER', { + 'type': 'add_user', + 'principal': instance.principal, + }) + + +@receiver(post_delete, sender=models.User) +def resync_subscriber_on_user_delete(sender, instance, **_kwargs): + # pylint: disable=unused-argument + utils.send_to_group('UP_OVERSEER', { + 'type': 'del_user', + 'principal': instance.principal, + }) diff --git a/roost_backend/user_process.py b/roost_backend/user_process.py index 8cc77e8..ce1f239 100644 --- a/roost_backend/user_process.py +++ b/roost_backend/user_process.py @@ -1,18 +1,26 @@ import asyncio +import base64 import functools import logging import multiprocessing as mp import os +import random +import select import signal from asgiref.sync import sync_to_async, async_to_sync import channels.consumer +from channels.db import database_sync_to_async import channels.layers import channels.utils import django import django.apps from django.core.exceptions import AppRegistryNotReady +from django.db import IntegrityError, transaction +from djangorestframework_camel_case.util import underscoreize import setproctitle +import zephyr +import _zephyr from . import utils @@ -21,8 +29,9 @@ class _MPDjangoSetupMixin: """This mixin runs django.setup() on __init__. It is to be used by classes that are - mp.Process targets. - """ + mp.Process targets.""" + # pylint: disable=too-few-public-methods + def __init__(self): try: django.apps.apps.check_models_ready() @@ -35,8 +44,8 @@ class _ChannelLayerMixin: """This mixin can be used to add Django Channels Layers support to a class. 
To use it, inherit from it and define a member `groups` or property
    `groups` of no arguments that returns an iterable of groups to
    subscribe to. Then start a task to run the `channel_layer_handler`, cancel it when you
-    want to stop. This may be worth extracting to a utility module.
-    """
+    want to stop. This may be worth extracting to a utility module."""
+
     def __init__(self):
         super().__init__()
         self.channel_layer = None
@@ -78,13 +87,17 @@ async def dispatch(self, message):
             raise ValueError(f'No handler for message type "{message["type"]}"')
 
 
-class _ZephyrProcessMixin:
+class _ZephyrProcessMixin(_ChannelLayerMixin):
+    """This mixin contains the core zephyr support for the User Processes and Server Process."""
+
     def __init__(self):
         super().__init__()
         # Event to indicate that zephyr has been initialized
         self.z_initialized = mp.Event()
         # Lock to be used around non-threadsafe bits of libzephyr.
-        self.zephyr_lock = mp.Lock()
+        self.zephyr_lock = None
+        self.resync_event = None
+        self.waiting_for_acks = {}
 
     @property
     def principal(self):
@@ -103,7 +116,7 @@ def _initialize_memory_ccache(self):
     def _add_credential_to_ccache(self, creds):
         utils.kerberos.add_credential_to_ccache(creds, self.principal)
         self.zinit()
-        self.resync_subs()
+        self.resync_event.set()
 
     @staticmethod
     def _have_valid_zephyr_creds():
@@ -117,31 +130,45 @@ def zinit(self):
             zephyr.init()
             self.z_initialized.set()
 
+    async def _sub(self, zsub, sub):
+        async with self.zephyr_lock:
+            zsub.add(sub)
+
     def resync_subs(self):
         if not self.z_initialized.is_set():
             return
-        _LOGGER.debug('[%s] zinit done, subscribing...', self.log_prefix)
+        _LOGGER.debug('[%s] resyncing subscriptions...', self.log_prefix)
         zsub = zephyr.Subscriptions()
-        if self.principal is not None:
-            # Don't unsub when destroying the Subscriptions object so we can use dump/loadSession.
-            # Only relevant for user process (with has principal)
-            zsub.cleanup = False
+        zsub.cleanup = False
         zsub.resync()
         subs_qs = self.get_subs_qs()
         for sub in set(subs_qs.values_list('class_key', 'instance_key', 'zrecipient')):
             _LOGGER.debug(' %s', sub)
-            with self.zephyr_lock:
-                zsub.add(sub)
+            async_to_sync(self._sub)(zsub, sub)
         _LOGGER.debug('[%s] %s', self.log_prefix, zsub)
         # TODO: check for extra subs and get rid of them.
         _LOGGER.debug('[%s] subscribing done.', self.log_prefix)
 
+    async def resync_handler(self):
+        _LOGGER.debug('[%s] resync task started.', self.log_prefix)
+        try:
+            while True:
+                await self.resync_event.wait()
+                self.resync_event.clear()
+                _LOGGER.debug('[%s] resync task triggered.', self.log_prefix)
+                await database_sync_to_async(self.resync_subs)()
+        except asyncio.CancelledError:
+            _LOGGER.debug('[%s] resync task cancelled.', self.log_prefix)
+
     async def zephyr_handler(self):
+        self.zephyr_lock = asyncio.Lock()
+        self.resync_event = asyncio.Event()
+        resync_task = asyncio.create_task(self.resync_handler())
         _LOGGER.debug('[%s] zephyr handler started.', self.log_prefix)
         try:
             await self.load_user_data()
@@ -150,7 +177,8 @@ async def zephyr_handler(self):
             await sync_to_async(self.z_initialized.wait)()
             _LOGGER.debug('[%s] zephyr handler now receiving...', self.log_prefix)
             while True:
-                with self.zephyr_lock:
+                _LOGGER.debug('[%s] zephyr handler loop start...', self.log_prefix)
+                async with self.zephyr_lock:
                     # Since we're calling this non-blocking, not bothering to wrap and await.
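                    # (Reviewer note, not part of the patch: because the
                    # receive is non-blocking, the lock is held only across the
                    # call itself, never while waiting for traffic; when
                    # nothing is pending, the loop waits for readability on
                    # libzephyr's socket instead. A rough sketch of that wait,
                    # assuming _zephyr.getFD() exposes the underlying file
                    # descriptor:
                    #
                    #     fd = _zephyr.getFD()
                    #     loop = asyncio.get_running_loop()
                    #     await loop.run_in_executor(None, select.select, [fd], [], [])
                    # )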
notice = zephyr.receive() @@ -162,14 +190,20 @@ async def zephyr_handler(self): _LOGGER.debug('[%s] data on FD...', self.log_prefix) continue - _LOGGER.debug('%s, %s', notice, notice.kind) + _LOGGER.debug('[%s] got: %s, %s', self.log_prefix, notice, notice.kind) + if notice.kind == zephyr.ZNotice.Kind.hmack: + # Ignore HM Acks + continue if notice.kind in (zephyr.ZNotice.Kind.servnak, - zephyr.ZNotice.Kind.servack, - zephyr.ZNotice.Kind.hmack): - # It would be cool to send ACK/NAKs to the user, - # but it is not clear what roost actually sent back, - # and no client actually did more than log it. + zephyr.ZNotice.Kind.servack): + # TODO: maybe do something different for servnak? + key = utils.notice_to_zuid_key(notice) + ack_reply_channel = self.waiting_for_acks.pop(key, None) + if ack_reply_channel: + await self.channel_layer.send(ack_reply_channel, { + 'ack': notice.fields[0].decode('utf-8') + }) continue if notice.opcode.lower() == 'ping': # Ignoring pings @@ -181,49 +215,74 @@ async def zephyr_handler(self): except asyncio.CancelledError: _LOGGER.debug('[%s] zephyr handler cancelled.', self.log_prefix) await self.save_user_data() + resync_task.cancel() + await resync_task finally: _LOGGER.debug('[%s] zephyr handler done.', self.log_prefix) @database_sync_to_async - def load_user_data(self): + def _load_user_data(self): if self.principal is None: - # Nothing to do for server process - return + # Server process + obj = django.apps.apps.get_model('roost_backend', 'ServerProcessState').load() + return obj.data obj = django.apps.apps.get_model('roost_backend', 'UserProcessState').objects.filter( user__principal=self.principal).first() if obj: - data = json.loads(obj.data) + return obj.data + return None + + async def load_user_data(self): + data = await self._load_user_data() + if data: if 'session_data' in data: # If we have session data, reinitialize libzephyr with it. session_data = base64.b64decode(data['session_data']) - with self.zephyr_lock: - zephyr.init(session_data=session_data) - self.z_initialized.set() + try: + async with self.zephyr_lock: + zephyr.init(session_data=session_data) + self.z_initialized.set() + except OSError: + pass if 'kerberos_data' in data and data['kerberos_data']: # If we have credentials, inject them into our ccache. # This will also initialize libzephyr if there was no session data. # TODO: filter out expired credentials? # TODO: support importing a list of credentials. - self._add_credential_to_ccache(data['kerberos_data']) - # obj.delete() + await database_sync_to_async(self._add_credential_to_ccache)(data['kerberos_data']) + if self.principal is None: + # The server process always has credentials; if we did not load state, initialize things now. 
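            # (Reviewer note, not part of the patch: start() seeded the server
            # process's ccache from the client keytab via
            # utils.kerberos.initialize_memory_ccache_from_client_keytab(),
            # which per the kerberos.py hunk below boils down to:
            #
            #     os.environ['KRB5CCNAME'] = 'MEMORY:'
            #     ccache = ctx.cc_default()
            #     ccache.init_from_keytab(ctx.kt_client_default())
            #
            # so zinit()/resync_subs() can authenticate immediately, with no
            # wait for injected credentials.)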
+ await sync_to_async(self.zinit)() + await database_sync_to_async(self.resync_subs)() @database_sync_to_async - def save_user_data(self): + def _save_user_data(self, data): if self.principal is None: - if not self._have_valid_zephyr_creds(): - utils.kerberos.initialize_memory_ccache_from_client_keytab() - with self.zephyr_lock: - _zephyr.cancelSubs() - # Nothing to do for server process - return + obj = django.apps.apps.get_model('roost_backend', 'ServerProcessState').load() + if 'kerberos_data' in data: + del data['kerberos_data'] + obj.data = data + obj.save() + else: + ups = django.apps.apps.get_model('roost_backend', 'UserProcessState') + try: + with transaction.atomic(): + ups.objects.update_or_create(user_id=self.uid, defaults={ + 'data': data, + }) + except IntegrityError: + _LOGGER.debug('[%s] saving user data failed; user deleted?', self.log_prefix) + return + _LOGGER.debug('[%s] saving user data done.', self.log_prefix) + + async def save_user_data(self): # TODO: support exporting multiple credentials. if not self.z_initialized.is_set(): return _LOGGER.debug('[%s] saving user data...', self.log_prefix) - ups = django.apps.apps.get_model('roost_backend', 'UserProcessState') - with self.zephyr_lock: + async with self.zephyr_lock: zephyr_session = _zephyr.dumpSession() zephyr_realm = _zephyr.realm() data = { @@ -231,14 +290,50 @@ def save_user_data(self): 'kerberos_data': underscoreize(utils.kerberos.get_zephyr_creds_dict(zephyr_realm)), } - try: - with transaction.atomic(): - ups.objects.update_or_create(user_id=self.uid, defaults={ - 'data': json.dumps(data), - }) - _LOGGER.debug('[%s] saving user data done.', self.log_prefix) - except IntegrityError: - _LOGGER.debug('[%s] saving user data failed; user deleted?', self.log_prefix) + for _ in range(4): + try: + await self._save_user_data(data) + break + except django.db.utils.OperationalError: + _LOGGER.warning('[%s] saving user data failed, trying again...', self.log_prefix) + await asyncio.sleep(random.random()) # jitter + else: + _LOGGER.error('[%s] saving user data failed, giving up.', self.log_prefix) + + # Start of Channel Layer message handlers + async def zwrite(self, message): + await sync_to_async(self.zinit)() + msg_args = message['message'] + reply_channel = message.pop('_reply_to', None) + + notice_args = { + k: v.encode() + for k, v in msg_args.items() + } + + if notice_args['recipient'].startswith(b'*'): + notice_args['recipient'] = notice_args['recipient'][1:] + notice_args['cls'] = notice_args.pop('class') + + notice = zephyr.ZNotice(**notice_args) + + async with self.zephyr_lock: + await sync_to_async(notice.send)() + if reply_channel is not None: + # Doing this under the lock ensures that we put the reply_channel in the dict before + # we can process any ACK. + self.waiting_for_acks[utils.notice_to_zuid_key(notice)] = reply_channel + + msg = django.apps.apps.get_model('roost_backend', 'Message').from_notice(notice, is_outgoing=True) + _LOGGER.debug('%s', msg) + if msg.is_personal: + # Only save outbound personals. + # TODO: re-evaluate this decision. 
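            # (Reviewer note, not part of the patch: this save fires
            # signals.message_post_processing, which for an outgoing personal
            # resolves the sender and echoes the message to their sockets,
            # roughly:
            #
            #     users.append(models.User.objects.get(principal=instance.sender))
            #     user.send_to_user_sockets({'type': 'incoming_message', ...})
            #
            # Outbound class messages are presumably skipped because the
            # zephyr servers bounce them back and the incoming path would
            # store them a second time.)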
+ await database_sync_to_async(msg.save)() + + async def resync_subscriptions(self, _message): + self.resync_event.set() + # End message handlers class Manager: @@ -303,7 +398,6 @@ def __init__(self, stop_event, start=True): def __str__(self): return f'Overseer<{self.pid}>' - def start(self): setproctitle.setproctitle('roost:OVERSEER') user_qs = django.apps.apps.get_model('roost_backend', 'User').objects.all() @@ -315,7 +409,9 @@ def start(self): async_to_sync(self.oversee)() async def oversee(self): + _LOGGER.debug('[OVERSEER] starting...') channel_task = asyncio.create_task(self.channel_layer_handler()) + server_task = asyncio.create_task(self.server_process_watcher()) for princ, task in self.user_tasks.items(): if task is None: self.user_tasks[princ] = asyncio.create_task(self.user_process_watcher(princ)) @@ -323,9 +419,21 @@ async def oversee(self): # We could just wait for the async tasks to finish, but then # we would not be waiting on any tasks for users created after # start-up, once we handle dynamic user creation. + _LOGGER.debug('[OVERSEER] waiting for stop event...') await sync_to_async(self.stop_event.wait, thread_sensitive=True)() - await asyncio.wait([task for task in self.user_tasks.values() if task is not None]) + _LOGGER.debug('[OVERSEER] received stop event...') + tasks = [server_task] + tasks.extend(task for task in self.user_tasks.values() if task is not None) + await asyncio.wait(tasks) channel_task.cancel() + await channel_task + _LOGGER.debug('[OVERSEER] done.') + + async def server_process_watcher(self): + while not self.stop_event.is_set(): + proc = mp.Process(target=ServerProcess, args=(self.stop_event,)) + proc.start() + await sync_to_async(proc.join, thread_sensitive=True)() async def user_process_watcher(self, principal): while not self.stop_event.is_set(): @@ -340,18 +448,25 @@ async def add_user(self, message): princ = message['principal'] if princ not in self.user_tasks: self.user_tasks[princ] = asyncio.create_task(self.user_process_watcher(princ)) + + async def del_user(self, message): + # {'type': 'del_user', 'principal': ''} + # Kills user process for user if running. + princ = message['principal'] + task = self.user_tasks.pop(princ, None) + if task: + task.cancel() # End message handlers -class UserProcess(_MPDjangoSetupMixin, _ChannelLayerMixin): - """ - Kerberos and zephyr are not particularly threadsafe, so each user - will have their own process. - """ +class UserProcess(_MPDjangoSetupMixin, _ZephyrProcessMixin): + """Kerberos and zephyr are not particularly threadsafe, so each user + will have their own process.""" def __init__(self, principal, stop_event, start=True): super().__init__() - self.principal = principal + self._principal = principal + self.uid = None self.stop_event = stop_event if start: self.start() @@ -359,56 +474,95 @@ def __init__(self, principal, stop_event, start=True): def __str__(self): return f'UserProcess<{self.principal}>' - def _initialize_memory_ccache(self): - utils.kerberos.initialize_memory_ccache(self.principal) - - def _add_credential_to_ccache(self, creds): - utils.kerberos.add_credential_to_ccache(creds) - - @property def groups(self): # The _ChannelLayerMixin requires us to define this. return [utils.principal_to_user_process_group_name(self.principal)] + @property + def principal(self): + # The _ZephyrProcessMixin requires us to define this. + return self._principal + + def get_subs_qs(self): + # The _ZephyrProcessMixin requires us to define this. 
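        # (Reviewer note, not part of the patch: a user process owns only
        # personal subscriptions, i.e. rows whose zrecipient equals the user's
        # own principal; shared class subscriptions belong to the single
        # ServerProcess, whose get_subs_qs filters on zrecipient='' instead.
        # The signals.py hunk above routes resync requests the same way:
        #
        #     if instance.zrecipient == user.principal:
        #         user.send_to_user_process(payload)
        #     else:
        #         utils.send_to_group('ROOST_SERVER_PROCESS', payload)
        # )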
+ subs_qs = django.apps.apps.get_model('roost_backend', 'Subscription').objects.all() + subs_qs = subs_qs.filter(user__principal=self.principal, zrecipient=self.principal) + return subs_qs + def start(self): + _LOGGER.debug('%s starting...', self) setproctitle.setproctitle(f'roost:{self.principal}') + self.uid = django.apps.apps.get_model('roost_backend', 'User').objects.get(principal=self.principal).id self._initialize_memory_ccache() async_to_sync(self.run)() async def run(self): - channel_task = asyncio.create_task(self.channel_layer_handler()) zephyr_task = asyncio.create_task(self.zephyr_handler()) - await sync_to_async(self.stop_event.wait, thread_sensitive=True)() - channel_task.cancel() - zephyr_task.cancel() + channel_task = asyncio.create_task(self.channel_layer_handler()) + try: + await sync_to_async(self.stop_event.wait, thread_sensitive=True)() + finally: + channel_task.cancel() + zephyr_task.cancel() + await zephyr_task + await channel_task # Start of Channel Layer message handlers - async def test(self, message): - print(self.principal, 'test', message) + async def inject_credentials(self, message): + await database_sync_to_async(self._add_credential_to_ccache)(message['creds']) - async def zwrite(self, message): - print(self.principal, 'zwrite', message['message']) - reply_channel = message.get('_reply_to') - if reply_channel is not None: - await self.channel_layer.send(reply_channel, {'ack': 'stubbed'}) + async def have_valid_credentials(self, message): + reply_channel = message.pop('_reply_to') + valid_creds = await sync_to_async(self._have_valid_zephyr_creds)() + await self.channel_layer.send(reply_channel, {'valid': valid_creds}) + # End message handlers - async def subscribe(self, message): - print(self.principal, 'subscribe', message) - async def unsubscribe(self, message): - print(self.principal, 'unsubscribe', message) +class ServerProcess(_MPDjangoSetupMixin, _ZephyrProcessMixin): + """Like the UserProcess, but for shared subscriptions.""" - async def inject_credentials(self, message): - self._add_credential_to_ccache(message['creds']) - # End message handlers + def __init__(self, stop_event, start=True): + super().__init__() + self.uid = None + self.stop_event = stop_event + if start: + self.start() - async def zephyr_handler(self): - _LOGGER.debug('faux zephyr handler started.') + def __str__(self): + return 'ServerProcess' + + # The _ChannelLayerMixin requires us to define this. + groups = ['ROOST_SERVER_PROCESS'] + + @property + def principal(self): + # The _ZephyrProcessMixin requires us to define this. + return None + + @property + def log_prefix(self): + return 'ServerProcess' + + def get_subs_qs(self): + # The _ZephyrProcessMixin requires us to define this. 
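        # (Reviewer note, not part of the patch: the empty zrecipient marks a
        # shared class subscription, so this queryset is the complement of the
        # per-user one in UserProcess.get_subs_qs above; resync_subs then
        # dedupes across users via set(subs_qs.values_list('class_key',
        # 'instance_key', 'zrecipient')) and subscribes once per triple.)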
+ subs_qs = django.apps.apps.get_model('roost_backend', 'Subscription').objects.all() + subs_qs = subs_qs.filter(zrecipient='') + return subs_qs + + def start(self): + _LOGGER.debug('%s starting...', self) + setproctitle.setproctitle('roost:server_process') + utils.kerberos.initialize_memory_ccache_from_client_keytab() + async_to_sync(self.run)() + + async def run(self): + zephyr_task = asyncio.create_task(self.zephyr_handler()) + channel_task = asyncio.create_task(self.channel_layer_handler()) try: - while True: - await asyncio.sleep(1) - except asyncio.CancelledError: - _LOGGER.debug('faux zephyr handler cancelled.') + await sync_to_async(self.stop_event.wait, thread_sensitive=True)() finally: - _LOGGER.debug('faux zephyr handler done.') + channel_task.cancel() + zephyr_task.cancel() + await zephyr_task + await channel_task diff --git a/roost_backend/utils/__init__.py b/roost_backend/utils/__init__.py index c8928fd..e5f0ba6 100644 --- a/roost_backend/utils/__init__.py +++ b/roost_backend/utils/__init__.py @@ -1,5 +1,8 @@ +import functools import uuid +from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from . import kerberos @@ -24,6 +27,7 @@ def principal_to_user_socket_group_name(princ): # to actually maintain state, and the only purpose of opaque message # ids is to not trivially reveal how many messages you can't see. +@functools.lru_cache(maxsize=1024) def seal_message_id(msg_id: int) -> uuid.UUID: cipher = Cipher(algorithms.AES(MESSGE_ID_SEALING_KEY), modes.ECB()) enc = cipher.encryptor() @@ -31,8 +35,28 @@ def seal_message_id(msg_id: int) -> uuid.UUID: return uuid.UUID(bytes=cbytes) -def unseal_message_id(sealed_msg_id: uuid.UUID) -> int: +def unseal_message_id(sealed_msg_id) -> int: + if isinstance(sealed_msg_id, str): + sealed_msg_id = uuid.UUID(sealed_msg_id) cipher = Cipher(algorithms.AES(MESSGE_ID_SEALING_KEY), modes.ECB()) dec = cipher.decryptor() pbytes = dec.update(sealed_msg_id.bytes) + dec.finalize() return int.from_bytes(pbytes, 'big') + + +def notice_to_zuid_key(notice): + return ( + notice.uid.address, + notice.uid.time, + ) + + +def send_to_group(group_name, msg, wait_for_response=False): + channel_layer = get_channel_layer() + if wait_for_response: + channel_name = async_to_sync(channel_layer.new_channel)() + msg = dict(msg, _reply_to=channel_name) + async_to_sync(channel_layer.group_send)(group_name, msg) + if wait_for_response: + return async_to_sync(channel_layer.receive)(channel_name) + return None diff --git a/roost_backend/utils/kerberos.py b/roost_backend/utils/kerberos.py index a0af552..222f0c6 100644 --- a/roost_backend/utils/kerberos.py +++ b/roost_backend/utils/kerberos.py @@ -8,21 +8,35 @@ _LOGGER = logging.getLogger(__name__) + def principal_to_group_name(princ, group_type): b64_principal = base64.b64encode(princ.encode("utf-8")).decode("ascii") return f'_{group_type}_PRINC_{b64_principal.strip("=")}' -def initialize_memory_ccache(principal): +def initialize_memory_ccache(principal=None): + # Repoint to a new, in-memory credential cache. + os.environ['KRB5CCNAME'] = 'MEMORY:' + if principal: + ctx = k5.Context() + ccache = ctx.cc_default() + ccache.init_name(principal) + + +def initialize_memory_ccache_from_client_keytab(reinit=False): # Repoint to a new, in-memory credential cache. 
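    # (Reviewer note, not part of the patch: a 'MEMORY:' ccache is private to
    # the current process and evaporates on exit, so tickets never touch disk
    # and one principal's credentials cannot leak into another's. KRB5CCNAME
    # is process-global, which is also why each principal gets a whole OS
    # process rather than a thread. Hypothetical usage:
    #
    #     initialize_memory_ccache('someuser@ATHENA.MIT.EDU')
    # )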
os.environ['KRB5CCNAME'] = 'MEMORY:' ctx = k5.Context() ccache = ctx.cc_default() - ccache.init_name(principal) + if reinit: + ccache.destroy() + ccache = ctx.cc_default() + keytab = ctx.kt_client_default() + ccache.init_from_keytab(keytab) -def add_credential_to_ccache(creds): - # pylint: disable=protected-access, too-many-locals +def add_credential_to_ccache(creds, princ=None): + # pylint: disable=protected-access, too-many-locals, too-many-statements # all this should be abstracted away somewhere else. # This may be leaky. Consider kdestroy/re-init/re-ping zephyr servers. def json_name_bits_to_princ(ctx, realm, name): @@ -44,6 +58,13 @@ def _b64(data): return base64.b64decode(data) return data + def verify_same_princ(client): + if princ: + client_name = client.unparse_name() + given_client_name = princ.encode('utf-8') + if client_name != given_client_name: + raise ValueError(f'Ticket for wrong client: {client_name} vs {given_client_name}') + ctx = k5.Context() ccache = ctx.cc_default() kcreds = kc.krb5_creds() @@ -52,6 +73,8 @@ def _b64(data): # Extract and massage the principals server = json_name_bits_to_princ(ctx, creds['srealm'], creds['sname']) client = json_name_bits_to_princ(ctx, creds['crealm'], creds['cname']) + verify_same_princ(client) + tkt_server = json_name_bits_to_princ(ctx, creds['ticket']['realm'], creds['ticket']['sname']) @@ -102,14 +125,27 @@ def _b64(data): # and finally, store the new cred in the ccache. k5.krb5_cc_store_cred(ctx._handle, ccache._handle, kcreds) -def get_zephyr_creds(realm): + +def _get_zephyr_creds(realm): + context = k5.Context() + ccache = context.cc_default() + principal = ccache.get_principal() + zephyr = context.build_principal(realm, ['zephyr', 'zephyr']) + return ccache.get_credentials(principal, zephyr, cache_only=True) + + +def get_zephyr_creds_dict(realm): try: - context = k5.Context() - ccache = context.cc_default() - principal = ccache.get_principal() - zephyr = context.build_principal(realm, ['zephyr', 'zephyr']) - creds = ccache.get_credentials(principal, zephyr, cache_only=True) + creds = _get_zephyr_creds(realm) creds_dict = creds.to_dict() return creds_dict except k5.Error: return {} + + +def have_valid_zephyr_creds(realm): + try: + creds = _get_zephyr_creds(realm) + return creds.is_valid() + except k5.Error: + return False diff --git a/roost_backend/views.py b/roost_backend/views.py index 44d9ee7..16362cc 100644 --- a/roost_backend/views.py +++ b/roost_backend/views.py @@ -9,7 +9,7 @@ from rest_framework.response import Response from rest_framework import generics, permissions, status -from . import filters, models, serializers +from . 
import filters, models, serializers, utils
 
 COMMON_DECORATORS = [vary_on_headers('Authorization'), never_cache]
@@ -98,7 +98,7 @@ class SubscribeView(APIView):
     serializer_class = serializers.SubscriptionSerializer
 
     def post(self, request):
-        serializer = self.serializer_class(data=request.data, many=True, context={'request': request})
+        serializer = self.serializer_class(data=request.data['subscriptions'], many=True, context={'request': request})
         serializer.is_valid(raise_exception=True)
         vdata = serializer.validated_data
 
@@ -135,7 +135,7 @@ def get_queryset(self):
         reverse = request.query_params.get('reverse', False)
         inclusive = request.query_params.get('inclusive', False)
         offset = request.query_params.get('offset')
-        limit = int(request.query_params.get('limit', 0))
+        limit = int(request.query_params.get('count', 0))
 
         # clamp limit
         if limit < 1:
@@ -144,8 +144,7 @@
             limit = 100
 
         if offset:
-            offset = int(offset)
-            # TODO: seal/unseal offset
+            offset = utils.unseal_message_id(offset)
             # TODO: Double check this
             if inclusive and reverse:
                 qs = qs.filter(id__lte=offset)
@@ -162,6 +161,12 @@
         qs = filters.MessageFilter(**request.query_params).apply_to_queryset(qs)
         return qs[:limit]
 
+    def list(self, request, *args, **kwargs):
+        return Response({
+            'messages': self.serializer_class(self.get_queryset(), many=True).data,
+            'isDone': True,
+        })
+
 
 @method_decorator(COMMON_DECORATORS, name='dispatch')
 class MessageByTimeView(APIView):
@@ -179,15 +184,17 @@ def get(self, request):
 
 
 @method_decorator(COMMON_DECORATORS, name='dispatch')
 class ZephyrCredsView(APIView):
     def get(self, request):
-        # This should find out if we need to refresh the zephyr
-        # credentials for this user and let them know. For now, the
-        # answer is no, everything is fine.
+        response = request.user.send_to_user_process({
+            'type': 'have_valid_credentials',
+        }, wait_for_response=True)
+
         return Response({
-            'needsRefresh': False,
+            'needsRefresh': not response['valid'],
         })
 
     def post(self, request):
         # Accept, validate, and then promptly ignore credentials.
+        # If they were included, the auth layer pushed them to the user process.
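        # (Reviewer note, not part of the patch: "the auth layer" is the
        # authentication code that populated request.zephyr_credentials; on
        # the user-process side the delivery lands in the inject_credentials
        # handler from the user_process.py hunk above:
        #
        #     async def inject_credentials(self, message):
        #         await database_sync_to_async(self._add_credential_to_ccache)(message['creds'])
        #
        # so by the time this view runs there is nothing left to do but
        # report whether any credentials were supplied.)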
ret = request.zephyr_credentials is not None return Response({ 'refreshed': ret, @@ -199,7 +206,7 @@ class ZWriteView(APIView): serializer_class = serializers.OutgoingMessageSerializer def post(self, request): - serializer = self.serializer_class(data=request.data) + serializer = self.serializer_class(data=request.data['message']) serializer.is_valid(raise_exception=True) response = request.user.send_to_user_process({ 'type': 'zwrite', @@ -220,10 +227,8 @@ def post(self, request): # app.get('/v1/messages', requireUser # app.get('/v1/bytime', requireUser # app.post('/v1/zwrite', requireUser -# Stubbed: # app.get('/v1/zephyrcreds', requireUser # app.post('/v1/zephyrcreds', requireUser -# To do: # Also, a websocket at /v1/socket/websocket # message types: diff --git a/roost_ng/settings/gssapi.py b/roost_ng/settings/gssapi.py index 77969f1..76e9eec 100644 --- a/roost_ng/settings/gssapi.py +++ b/roost_ng/settings/gssapi.py @@ -9,3 +9,6 @@ if DEFAULT_KRB5_KTNAME and not os.environ.get('KRB5_KTNAME'): os.environ['KRB5_KTNAME'] = DEFAULT_KRB5_KTNAME + +if DEFAULT_SUBSCRIBER_KRB5_KEYTAB and not os.environ.get('KRB5_CLIENT_KTNAME'): + os.environ['KRB5_CLIENT_KTNAME'] = DEFAULT_SUBSCRIBER_KRB5_KEYTAB diff --git a/roost_ng/settings/logging.py b/roost_ng/settings/logging.py index 0341ea5..6117d82 100644 --- a/roost_ng/settings/logging.py +++ b/roost_ng/settings/logging.py @@ -19,5 +19,10 @@ 'level': os.getenv('DJANGO_LOG_LEVEL', 'INFO'), 'propagate': False, }, + 'roost_backend': { + 'handlers': ['console'], + 'level': os.getenv('ROOST_BACKEND_LOG_LEVEL', 'INFO'), + 'propagate': False, + }, }, }
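
Reviewer note, not part of the patch: a quick way to poke at the new websocket
protocol from a dev checkout. Illustrative only -- it assumes the third-party
`websockets` package, a JWT obtained out of band from the auth endpoint, the
/v1/socket/websocket path noted in the views.py comments, an initial auth
frame answered by {'type': 'ready'} as in UserSocketConsumer, and roost's
'new-tail'/'extend-tail' message type names (the consumer's dispatch table is
outside this diff, so the exact names are an assumption):

    import asyncio
    import json

    import websockets  # assumed dependency: pip install websockets


    async def demo(token):
        uri = 'ws://localhost:8000/v1/socket/websocket'  # hypothetical dev address
        async with websockets.connect(uri) as ws:
            # Authenticate; the consumer replies {'type': 'ready'} on success.
            await ws.send(json.dumps({'type': 'auth', 'token': token}))
            print(json.loads(await ws.recv()))
            # Open tail 1 from the beginning of history, then ask for up to
            # ten messages; replies arrive as {'type': 'messages', ...}.
            await ws.send(json.dumps({'type': 'new-tail', 'id': 1,
                                      'start': None, 'inclusive': False}))
            await ws.send(json.dumps({'type': 'extend-tail', 'id': 1, 'count': 10}))
            print(json.loads(await ws.recv()))

    # asyncio.run(demo('<JWT from the auth endpoint>'))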