|
29 | 29 | GroupOfQuestionsUpdateSerializer, |
30 | 30 | ) |
31 | 31 | from questions.services import get_aggregated_forecasts_for_questions |
| 32 | +from questions.types import AggregationMethod |
32 | 33 | from users.models import User |
33 | 34 | from utils.dtypes import flatten |
34 | 35 | from utils.serializers import SerializerKeyLookupMixin |
@@ -594,3 +595,90 @@ class PostRelatedArticleSerializer(serializers.ModelSerializer): |
594 | 595 | class Meta: |
595 | 596 | model = ITNArticle |
596 | 597 | fields = ("id", "title", "url", "favicon_url", "created_at", "media_label") |
| 598 | + |
| 599 | + |
| 600 | +class DownloadDataSerializer(serializers.Serializer): |
| 601 | + sub_question = serializers.IntegerField(required=False) |
| 602 | + aggregation_methods = serializers.CharField(required=False) |
| 603 | + user_ids = serializers.CharField(required=False, allow_null=True) |
| 604 | + include_comments = serializers.BooleanField(required=False, default=False) |
| 605 | + include_scores = serializers.BooleanField(required=False, default=False) |
| 606 | + include_bots = serializers.BooleanField(required=False, allow_null=True) |
| 607 | + minimize = serializers.BooleanField(required=False, default=True) |
| 608 | + |
| 609 | + def validate_aggregation_methods(self, value): |
| 610 | + if value is None: |
| 611 | + return |
| 612 | + user: User = self.context["user"] |
| 613 | + if value == "all": |
| 614 | + aggregation_methods = [ |
| 615 | + AggregationMethod.RECENCY_WEIGHTED, |
| 616 | + AggregationMethod.UNWEIGHTED, |
| 617 | + AggregationMethod.METACULUS_PREDICTION, |
| 618 | + ] |
| 619 | + if user.is_staff: |
| 620 | + aggregation_methods.append(AggregationMethod.SINGLE_AGGREGATION) |
| 621 | + return aggregation_methods |
| 622 | + methods = value.split(",") |
| 623 | + invalid_methods = [ |
| 624 | + method for method in methods if method not in AggregationMethod.values |
| 625 | + ] |
| 626 | + if invalid_methods: |
| 627 | + raise serializers.ValidationError( |
| 628 | + f"Invalid aggregation method(s): {', '.join(invalid_methods)}" |
| 629 | + ) |
| 630 | + if not user.is_staff: |
| 631 | + methods = [ |
| 632 | + method |
| 633 | + for method in methods |
| 634 | + if method != AggregationMethod.SINGLE_AGGREGATION |
| 635 | + ] |
| 636 | + return methods |
| 637 | + |
| 638 | + def validate_user_ids(self, value): |
| 639 | + if not value: |
| 640 | + return value |
| 641 | + user_ids = value.split(",") |
| 642 | + if not all(user_id.isdigit() for user_id in user_ids): |
| 643 | + raise serializers.ValidationError( |
| 644 | + "Invalid user_ids. Must be a comma-separated list of integers." |
| 645 | + ) |
| 646 | + if not self.context["can_view_private_data"]: |
| 647 | + raise serializers.ValidationError( |
| 648 | + "Current user cannot view user-specific data. " |
| 649 | + "Please remove user_ids parameter." |
| 650 | + ) |
| 651 | + uids = [int(user_id) for user_id in user_ids] |
| 652 | + return uids |
| 653 | + |
| 654 | + def validate(self, attrs): |
| 655 | + # Check if there are any unexpected fields |
| 656 | + allowed_fields = { |
| 657 | + "sub_question", |
| 658 | + "aggregation_methods", |
| 659 | + "user_ids", |
| 660 | + "include_comments", |
| 661 | + "include_scores", |
| 662 | + "include_bots", |
| 663 | + "minimize", |
| 664 | + } |
| 665 | + input_fields = set(self.initial_data.keys()) |
| 666 | + unexpected_fields = input_fields - allowed_fields |
| 667 | + if unexpected_fields: |
| 668 | + raise ValidationError(f"Unexpected fields: {', '.join(unexpected_fields)}") |
| 669 | + |
| 670 | + # Aggregation validation logic |
| 671 | + aggregation_methods = attrs.get("aggregation_methods") |
| 672 | + user_ids = attrs.get("user_ids") |
| 673 | + include_bots = attrs.get("include_bots") |
| 674 | + minimize = attrs.get("minimize", True) |
| 675 | + |
| 676 | + if not aggregation_methods and ( |
| 677 | + user_ids is not None or include_bots is not None or not minimize |
| 678 | + ): |
| 679 | + raise serializers.ValidationError( |
| 680 | + "If user_ids, include_bots, or minimize is set, " |
| 681 | + "aggregation_methods must also be set." |
| 682 | + ) |
| 683 | + |
| 684 | + return attrs |
0 commit comments