Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
import shutil
from pathlib import Path
from types import SimpleNamespace

import pytest

from django.conf import settings
from django.contrib.auth.models import User
from django.contrib.sites.models import Site
from django.core.management import call_command

from rdmo.accounts.utils import set_group_permissions
Expand Down Expand Up @@ -75,3 +77,18 @@ def delete_all(*models):
for model in models:
model.objects.all().delete()
return delete_all


@pytest.fixture
def sites(settings):
Site.objects.clear_cache()

def activate(domain):
site = Site.objects.get(domain=domain)
settings.SITE_ID = site.id
Site.objects.clear_cache()
return site

yield SimpleNamespace(activate=activate)

Site.objects.clear_cache()
62 changes: 53 additions & 9 deletions rdmo/conditions/tests/test_viewset_condition_multisite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,56 @@

from rdmo.core.tests.constants import multisite_status_map as status_map
from rdmo.core.tests.constants import multisite_users as users
from rdmo.core.tests.utils import get_obj_perms_status_code

from ..models import Condition
from .test_viewset_condition import export_formats, urlnames

STATUS_CODES = {
'detail': {
'https://foo.com/terms/conditions/foo-condition': {
'user': 404, 'reviewer': 200, 'editor': 200,
'example-reviewer': 200, 'example-editor': 200, 'foo-user': 404,
'foo-reviewer': 200, 'foo-editor': 200, 'bar-user': 404,
'bar-reviewer': 404, 'bar-editor': 404, 'anonymous': 401,
},
'https://bar.com/terms/conditions/bar-condition': {
'user': 404, 'reviewer': 200, 'editor': 200,
'example-reviewer': 200, 'example-editor': 200, 'foo-user': 404,
'foo-reviewer': 404, 'foo-editor': 404, 'bar-user': 404,
'bar-reviewer': 200, 'bar-editor': 200, 'anonymous': 401,
},
},
'update': {
'https://foo.com/terms/conditions/foo-condition': {
'user': 404, 'reviewer': 403, 'editor': 200,
'example-reviewer': 404, 'example-editor': 404, 'foo-user': 404,
'foo-reviewer': 403, 'foo-editor': 200, 'bar-user': 404,
'bar-reviewer': 404, 'bar-editor': 404, 'anonymous': 401,
},
'https://bar.com/terms/conditions/bar-condition': {
'user': 404, 'reviewer': 403, 'editor': 200,
'example-reviewer': 404, 'example-editor': 404, 'foo-user': 404,
'foo-reviewer': 404, 'foo-editor': 404, 'bar-user': 404,
'bar-reviewer': 403, 'bar-editor': 200, 'anonymous': 401,
},
},
'delete': {
'https://foo.com/terms/conditions/foo-condition': {
'user': 404, 'reviewer': 403, 'editor': 204,
'example-reviewer': 404, 'example-editor': 404, 'foo-user': 404,
'foo-reviewer': 403, 'foo-editor': 204, 'bar-user': 404,
'bar-reviewer': 404, 'bar-editor': 404, 'anonymous': 401,
},
'https://bar.com/terms/conditions/bar-condition': {
'user': 404, 'reviewer': 403, 'editor': 204,
'example-reviewer': 404, 'example-editor': 404, 'foo-user': 404,
'foo-reviewer': 404, 'foo-editor': 404, 'bar-user': 404,
'bar-reviewer': 403, 'bar-editor': 204, 'anonymous': 401,
},
},
}



@pytest.mark.parametrize('username,password', users)
def test_list(db, client, username, password):
Expand Down Expand Up @@ -54,9 +99,9 @@ def test_detail(db, client, username, password):
for instance in instances:
url = reverse(urlnames['detail'], args=[instance.pk])
response = client.get(url)
assert response.status_code == get_obj_perms_status_code(
instance, username, 'detail'
), (response.json(), instance.editors.all())
assert response.status_code == STATUS_CODES['detail'].get(instance.uri, status_map['detail'])[username], (
response.json(), instance.editors.all()
)


@pytest.mark.parametrize('username,password', users)
Expand Down Expand Up @@ -231,9 +276,9 @@ def test_update(db, client, username, password):
'target_option': instance.target_option.pk if instance.target_option else None
}
response = client.put(url, data, content_type='application/json')
assert response.status_code == get_obj_perms_status_code(
instance, username, 'update'
), (response.json(), instance.editors.all())
assert response.status_code == STATUS_CODES['update'].get(instance.uri, status_map['update'])[username], (
response.json(), instance.editors.all()
)


@pytest.mark.parametrize('username,password', users)
Expand All @@ -242,10 +287,9 @@ def test_delete(db, client, username, password):
instances = Condition.objects.all()

for instance in instances:
editors = list(instance.editors.values_list('domain', flat=True))
url = reverse(urlnames['detail'], args=[instance.pk])
response = client.delete(url)
assert response.status_code == get_obj_perms_status_code(instance, username, 'delete', editors=editors)
assert response.status_code == STATUS_CODES['delete'].get(instance.uri, status_map['delete'])[username]


@pytest.mark.parametrize('username,password', users)
Expand Down
5 changes: 5 additions & 0 deletions rdmo/core/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,8 @@ def has_permission(self, request, view) -> bool:

# the viewset needs to set permission_required
return request.user.has_perm(view.permission_required)


def get_object_permission(model_or_instance, action):
Comment thread
jochenklar marked this conversation as resolved.
model = model_or_instance if isinstance(model_or_instance, type) else model_or_instance._meta.model
return f'{model._meta.app_label}.{action}_{model._meta.model_name}_object'
87 changes: 58 additions & 29 deletions rdmo/core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from django.db.models import Max

from rest_framework import serializers
from rest_framework.exceptions import PermissionDenied
from rest_framework.utils import model_meta

from rdmo.core.permissions import get_object_permission
from rdmo.core.utils import get_language_warning, get_languages, markdown2html

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -74,6 +76,11 @@ def __init__(self, *args, **kwargs):

class ThroughModelSerializerMixin:

def validate(self, attrs):
attrs = super().validate(attrs)
self.check_parent_object_permissions(attrs)
return attrs

def create(self, validated_data):
parent_fields = self.get_parent_fields(validated_data)
through_fields = self.get_through_fields(validated_data)
Expand Down Expand Up @@ -143,22 +150,9 @@ def set_through_fields(self, instance, through_fields):
return instance

def get_parent_fields(self, validated_data):
try:
self.Meta.parent_fields # noqa: B018
except AttributeError:
return None

model_info = model_meta.get_field_info(self.Meta.model)

parent_fields = {}
for field_name, _source_name, _target_name, through_name in self.Meta.parent_fields:
parent_model = model_info.reverse_relations[field_name].related_model
parent_model_info = model_meta.get_field_info(parent_model)

through_model = parent_model_info.reverse_relations[through_name].related_model

parent_fields[field_name] = (through_model, validated_data.pop(field_name, None))

parent_fields = self.resolve_parent_fields(validated_data)
for field_name, *_ in parent_fields:
validated_data.pop(field_name, None)
return parent_fields

def set_parent_fields(self, instance, parent_fields):
Expand All @@ -167,13 +161,11 @@ def set_parent_fields(self, instance, parent_fields):
except AttributeError:
return instance

for field_name, source_name, target_name, through_name in self.Meta.parent_fields:
through_model, validated_data = parent_fields[field_name]

if validated_data is None:
for _field_name, source_name, target_name, through_name, through_model, parents in parent_fields:
if parents is None:
continue

for parent in validated_data:
for parent in parents:
order = (getattr(parent, through_name).aggregate(order=Max('order')).get('order') or 0) + 1
through_model(**{
source_name: parent,
Expand All @@ -183,6 +175,50 @@ def set_parent_fields(self, instance, parent_fields):

return instance

def resolve_parent_fields(self, data):
Comment thread
MyPyDavid marked this conversation as resolved.
try:
self.Meta.parent_fields # noqa: B018
except AttributeError:
return []

model_info = model_meta.get_field_info(self.Meta.model)

parent_fields = []
for field_name, source_name, target_name, through_name in self.Meta.parent_fields:
parent_model = model_info.reverse_relations[field_name].related_model
parent_model_info = model_meta.get_field_info(parent_model)

through_model = parent_model_info.reverse_relations[through_name].related_model

parent_fields.append((
field_name,
source_name,
target_name,
through_name,
through_model,
data.get(field_name),
))

return parent_fields

def check_parent_object_permissions(self, attrs):
# Parent fields create through-model rows on existing parent elements when this
# serializer creates a new child element. That effectively changes the
# parent element (for example adding a section to an existing catalog), so
# require object-level change permission for each supplied parent.
if self.instance is not None:
return

request = self.context.get('request')
if request is None:
return

for *_, parents in self.resolve_parent_fields(attrs):
Comment thread
MyPyDavid marked this conversation as resolved.
for parent in parents or []:
permission = get_object_permission(parent, 'change')
if not request.user.has_perm(permission, parent):
raise PermissionDenied()


class ElementModelSerializerMixin(serializers.ModelSerializer):

Expand Down Expand Up @@ -217,19 +253,12 @@ class ReadOnlyObjectPermissionSerializerMixin:

OBJECT_PERMISSION_ACTION_NAMES = ('change', 'delete')

@staticmethod
def construct_object_permission(model, action_name: str) -> str:
model_app_label = model._meta.app_label
model_name = model._meta.model_name
perm = f'{model_app_label}.{action_name}_{model_name}_object'
return perm

def get_read_only(self, obj) -> bool:
request = self.context.get('request')
if request is None:
return False
user = request.user
perms = (self.construct_object_permission(self.Meta.model, action_name)
perms = (get_object_permission(self.Meta.model, action_name)
for action_name in self.OBJECT_PERMISSION_ACTION_NAMES)
return not all(user.has_perm(perm, obj) for perm in perms)

Expand Down
Loading
Loading