-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathviews.py
267 lines (211 loc) · 9.31 KB
/
views.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import collections.abc
import datetime

from asgiref.sync import async_to_sync
import channels.layers
from django.conf import settings
from django.db import transaction
from django.http import HttpResponse
from django.utils.decorators import method_decorator
from django.views.decorators.cache import never_cache
from django.views.decorators.vary import vary_on_headers
import gssapi
from rest_framework import exceptions, generics, permissions, status
from rest_framework.response import Response
from rest_framework.views import APIView

from . import filters, models, serializers, utils
# Decorators applied (via method_decorator on `dispatch`) to the API views
# below: responses vary on the Authorization header and are never cached.
COMMON_DECORATORS = [
    vary_on_headers('Authorization'),
    never_cache,
]
@method_decorator(never_cache, name='dispatch')
class AuthView(APIView):
    """Authenticate a client via a single-step GSSAPI handshake.

    The POST body carries a GSS `token` and a `create_user` flag. On success
    the response contains the server's GSS reply token plus the fields from
    `user.get_auth_token_dict()`. Missing users may be created if the
    ROOST_ALLOW_USER_CREATION / ROOST_USER_CREATION_DENYLIST settings permit.
    """
    authentication_classes = []
    permission_classes = [permissions.AllowAny]
    serializer_class = serializers.AuthSerializer

    def handle_exception(self, exc):
        """Report GSSAPI failures as HTTP 400; defer everything else to DRF."""
        if isinstance(exc, gssapi.exceptions.GSSError):
            return Response(f'{type(exc).__name__}: {exc}', status=status.HTTP_400_BAD_REQUEST)
        return super().handle_exception(exc)

    @staticmethod
    async def wait_for_user_process(principal):
        """Wait until a message naming `principal` arrives on its announce channel.

        Messages on that channel whose 'principal' differs are consumed and
        discarded. NOTE(review): there is no timeout -- if no announcement
        ever arrives, the caller blocks indefinitely.
        """
        channel_layer = channels.layers.get_channel_layer()
        channel_name = utils.principal_to_user_subscriber_announce_channel(principal)
        while True:
            message = await channel_layer.receive(channel_name)
            if message.get('principal') == principal:
                return

    @staticmethod
    def _user_creation_allowed(principal):
        """Return whether settings permit creating a user for `principal`.

        ROOST_ALLOW_USER_CREATION may be exactly True (allow everyone not in
        ROOST_USER_CREATION_DENYLIST) or a container of allowed principals.
        NOTE(review): a str setting is also a Container, so it would act as a
        substring test -- presumably a list/set is intended.
        """
        allow = settings.ROOST_ALLOW_USER_CREATION
        if allow is True:  # explicit check for bool here
            return principal not in settings.ROOST_USER_CREATION_DENYLIST
        return isinstance(allow, collections.abc.Container) and principal in allow

    def post(self, request):
        """Run the GSS handshake, optionally create the user, and issue a token."""
        serializer = self.serializer_class(data=request.data)
        serializer.is_valid(raise_exception=True)

        server_creds = gssapi.Credentials(usage='accept')
        ctx = gssapi.SecurityContext(creds=server_creds, usage='accept')
        gss_token = ctx.step(serializer.validated_data['token'])
        if not ctx.complete:
            return Response("Roost-ng does not support multi-step GSS handshakes.",
                            status=status.HTTP_400_BAD_REQUEST)
        principal = str(ctx.initiator_name)

        user = models.User.objects.filter(principal=principal).first()
        if user is None and serializer.validated_data['create_user']:
            if not self._user_creation_allowed(principal):
                return HttpResponse(f'User creation not allowed for user {principal}',
                                    status=status.HTTP_403_FORBIDDEN)
            with transaction.atomic():
                user, created = models.User.objects.get_or_create(principal=principal)
            # Block on the user process only after the transaction has
            # committed. Waiting inside it kept the new row invisible to
            # other connections and held the transaction (and its locks) open
            # for an unbounded time. (If ATOMIC_REQUESTS is enabled, an outer
            # request-level transaction would still be open here.)
            if created:
                async_to_sync(self.wait_for_user_process)(user.principal)
        if user is None:
            return HttpResponse('User does not exist', status=status.HTTP_403_FORBIDDEN)

        resp = self.serializer_class({
            'gss_token': gss_token,
            **user.get_auth_token_dict(),
        })
        return Response(resp.data)
@method_decorator(COMMON_DECORATORS, name='dispatch')
class PingView(APIView):
    """Liveness probe: every GET is answered with ``{"pong": 1}``."""

    def get(self, request):
        pong = {'pong': 1}
        return Response(pong)
@method_decorator(COMMON_DECORATORS, name='dispatch')
class InfoView(APIView):
    """Read, or conditionally replace, the requesting user's ``info`` field.

    Writes carry the ``expected_version`` the client last saw. The write is
    applied (and ``info_version`` bumped) only when that still matches.
    """
    serializer_class = serializers.InfoSerializer

    def get(self, request):
        """Return the serialized info of the requesting user."""
        return Response(self.serializer_class(self.request.user).data)

    def post(self, request):
        """Replace ``info`` if ``expected_version`` matches the stored version.

        Responds ``{'updated': True}`` on success. Otherwise it responds
        ``{'updated': False}`` merged with the serialized current user state.
        """
        incoming = self.serializer_class(data=request.data)
        incoming.is_valid(raise_exception=True)
        fields = incoming.validated_data
        with transaction.atomic():
            # Row lock: concurrent writers queue up on the version comparison.
            current = (models.User.objects
                       .select_for_update()
                       .filter(pk=self.request.user.pk)
                       .first())
            matched = current.info_version == fields['expected_version']
            if matched:
                current.info_version += 1
                current.info = fields['info']
                current.save()
        if not matched:
            return Response({'updated': False, **self.serializer_class(current).data})
        return Response({'updated': True})
@method_decorator(COMMON_DECORATORS, name='dispatch')
class SubscriptionView(generics.ListAPIView):
    """List the requesting user's subscriptions."""
    serializer_class = serializers.SubscriptionSerializer

    def get_queryset(self):
        # DRF expects get_queryset() to return a QuerySet, not a related
        # manager; `.all()` gives a fresh QuerySet for filter backends,
        # paginators, and the serializer alike.
        return self.request.user.subscription_set.all()
@method_decorator(COMMON_DECORATORS, name='dispatch')
class SubscribeView(APIView):
    """Add subscriptions for the requesting user."""
    serializer_class = serializers.SubscriptionSerializer

    def post(self, request):
        """Subscribe to every entry of ``request.data['subscriptions']``.

        Responds with the serialized result of ``user.add_subscriptions``.
        Responds 400 if the ``subscriptions`` key is missing or the body is not
        a mapping, or if any entry fails validation.
        """
        try:
            raw_subs = request.data['subscriptions']
        except (KeyError, TypeError):
            # Previously an unhandled exception (HTTP 500); this is a client error.
            raise exceptions.ValidationError(
                {'subscriptions': ['This field is required.']}) from None
        serializer = self.serializer_class(data=raw_subs, many=True, context={'request': request})
        serializer.is_valid(raise_exception=True)
        subs = self.request.user.add_subscriptions(serializer.validated_data)
        return Response(self.serializer_class(subs, many=True).data)
@method_decorator(COMMON_DECORATORS, name='dispatch')
class UnsubscribeView(APIView):
    """Remove one subscription for the requesting user."""
    serializer_class = serializers.SubscriptionSerializer

    def post(self, request):
        """Unsubscribe from the triple in ``request.data['subscription']``.

        Responds with the serialized subscription returned by
        ``user.remove_subscription``. Responds 400 if the ``subscription`` key
        is missing or the body is not a mapping, or if it fails validation.
        """
        try:
            raw_sub = request.data['subscription']
        except (KeyError, TypeError):
            # Previously an unhandled exception (HTTP 500); this is a client error.
            raise exceptions.ValidationError(
                {'subscription': ['This field is required.']}) from None
        serializer = self.serializer_class(data=raw_sub)
        serializer.is_valid(raise_exception=True)
        vdata = serializer.validated_data
        sub, _removed = self.request.user.remove_subscription(
            vdata['zclass'], vdata['zinstance'], vdata['zrecipient'])
        return Response(self.serializer_class(sub).data)
@method_decorator(COMMON_DECORATORS, name='dispatch')
class MessageView(generics.ListAPIView):
    """Page through the requesting user's messages."""
    serializer_class = serializers.MessageSerializer

    @classmethod
    def prepare_response_payload(cls, qs, params):
        """Select one page of messages from ``qs`` and serialize it.

        Args:
            qs: queryset of messages to page through.
            params: mapping of parameters. ``reverse`` and ``inclusive`` are
                integer-like flags (default off). ``offset`` is a sealed message
                id to page from (decoded by ``utils.unseal_message_id``).
                ``count`` is the page size, clamped to
                1..``settings.ROOST_MESSAGES_MAX_LIMIT``. The whole mapping is
                also passed as keyword arguments to ``filters.MessageFilter``.

        Returns:
            ``{'messages': [...], 'isDone': bool}``. ``isDone`` is true when no
            further matching message exists beyond this page.

        Raises:
            ValueError / TypeError: if ``reverse``, ``inclusive`` or ``count``
            cannot be converted with ``int()``.
        """
        reverse = int(params.get('reverse', False))
        inclusive = int(params.get('inclusive', False))
        offset = params.get('offset')
        limit = int(params.get('count', 0))
        # Clamp the page size: never below 1, never above the configured cap.
        if limit < 1:
            limit = 1
        elif limit > settings.ROOST_MESSAGES_MAX_LIMIT:
            limit = settings.ROOST_MESSAGES_MAX_LIMIT

        if offset:
            offset = utils.unseal_message_id(offset)
            # Reverse pages take smaller ids, forward pages larger ones;
            # `inclusive` also keeps the message at `offset` itself.
            # TODO: Double check this against the original Roost semantics.
            lookup = 'id__lt' if reverse else 'id__gt'
            if inclusive:
                lookup += 'e'
            qs = qs.filter(**{lookup: offset})
        if reverse:
            qs = qs.reverse()
        qs = filters.MessageFilter(**params).apply_to_queryset(qs)

        # Fetch one row beyond the page to detect whether more remain.
        messages = cls.serializer_class(qs[:limit + 1], many=True).data
        return {
            'messages': messages[:limit],
            'isDone': len(messages) <= limit,
        }

    def list(self, request, *args, **kwargs):
        qs = request.user.message_set.all()
        # QueryDict is a dict subclass whose underlying values are *lists*.
        # Unpacking it with ``**`` (as prepare_response_payload does for
        # MessageFilter) would pass lists instead of strings. ``.dict()`` keeps
        # the last value per key -- the same value QueryDict.get() returns --
        # so the parameters read directly above are unaffected.
        payload = self.prepare_response_payload(qs, request.query_params.dict())
        return Response(payload)
@method_decorator(COMMON_DECORATORS, name='dispatch')
class MessageByTimeView(APIView):
    """Find the user's first message received at or after a given time."""

    def get(self, request):
        """Respond ``{'id': <message id or None>}``.

        Query parameters:
            time: milliseconds since the Unix epoch, as an integer string.

        Responds 400 if ``time`` is missing, not an integer, or outside the
        range ``datetime`` can represent.
        """
        raw_time = request.query_params.get('time')
        if raw_time is None:
            return Response('time not specified', status=status.HTTP_400_BAD_REQUEST)
        try:
            time = datetime.datetime.fromtimestamp(int(raw_time) / 1000, datetime.timezone.utc)
        except (ValueError, OverflowError, OSError):
            # Bad client input previously escaped as an unhandled exception (500).
            return Response('invalid time', status=status.HTTP_400_BAD_REQUEST)
        msg = (request.user.message_set
               .filter(receive_time__gte=time)
               .order_by('receive_time')
               .first())
        return Response({
            'id': msg.id if msg is not None else None,
        })
@method_decorator(COMMON_DECORATORS, name='dispatch')
class ZephyrCredsView(APIView):
    """Check on, or accept, the user's zephyr credentials."""

    def get(self, request):
        """Ask the user's subscriber process whether its credentials are valid.

        Responds ``{'needsRefresh': bool}``, the negation of the reply's
        ``valid`` field.
        """
        reply = request.user.send_to_user_subscriber(
            {'type': 'have_valid_credentials'},
            wait_for_response=True,
        )
        needs_refresh = not reply['valid']
        return Response({'needsRefresh': needs_refresh})

    def post(self, request):
        """Respond ``{'refreshed': bool}`` -- whether credentials came with the request."""
        # Accept, validate, and then promptly ignore credentials.
        # If they were included, the auth layer pushed them to the user process.
        return Response({'refreshed': request.zephyr_credentials is not None})
@method_decorator(COMMON_DECORATORS, name='dispatch')
class ZWriteView(APIView):
    """Send an outgoing message through the user's subscriber process."""
    serializer_class = serializers.OutgoingMessageSerializer

    def post(self, request):
        """Validate ``request.data['message']`` and forward it as a ``zwrite``.

        Responds with whatever the subscriber process replies. Responds 400 if
        the ``message`` key is missing or the body is not a mapping, or if the
        message fails validation.
        """
        try:
            raw_message = request.data['message']
        except (KeyError, TypeError):
            # Previously an unhandled exception (HTTP 500); this is a client error.
            raise exceptions.ValidationError(
                {'message': ['This field is required.']}) from None
        serializer = self.serializer_class(data=raw_message)
        serializer.is_valid(raise_exception=True)
        response = request.user.send_to_user_subscriber({
            'type': 'zwrite',
            'message': serializer.validated_data,
        }, wait_for_response=True)
        return Response(response)
# Roost's endpoints and their status in this module.
# Done ("requireUser" = the route requires an authenticated user):
#   POST /v1/auth
#   GET  /v1/ping           (requireUser)
#   GET  /v1/info           (requireUser)
#   POST /v1/info           (requireUser)
#   GET  /v1/subscriptions  (requireUser)
#   POST /v1/subscribe      (requireUser)
#   POST /v1/unsubscribe    (requireUser)
#   GET  /v1/messages       (requireUser)
#   GET  /v1/bytime         (requireUser)
#   POST /v1/zwrite         (requireUser)
#   GET  /v1/zephyrcreds    (requireUser)
#   POST /v1/zephyrcreds    (requireUser)
# Roost also serves a websocket at /v1/socket/websocket with these message types:
#   - ping (-> pong)
#   - new-tail
#   - extend-tail
#   - close-tail