Skip to content

Commit 2d8971a

Browse files
authored
CSV download improvements (#1775)
Change endpoint to download-data from download-csv add endpoint to projects Adds a Whitelist to allow for specific non-admin users to access individual forecast-level data
1 parent 94501fa commit 2d8971a

File tree

19 files changed

+741
-315
lines changed

19 files changed

+741
-315
lines changed

docs/openapi.yml

Lines changed: 265 additions & 98 deletions
Large diffs are not rendered by default.

front_end/messages/en.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,15 +588,15 @@
588588
"boost": "Boost",
589589
"bury": "Bury",
590590
"duplicate": "Duplicate",
591-
"downloadCSV": "Download CSV",
591+
"downloadQuestionData": "Download Question data",
592592
"partnersUseForecasts": "Our partners use Metaculus forecasts to gain insight into complex challenges and decisions.",
593593
"learnHowYouCanPartner": "Learn how you can partner with Metaculus to host a forecasting tournament, collaborate on forecasting research, hire our <link>Pro Forecasters</link> for custom projects, or set up a private forecasting space for your organization.",
594594
"workingWithNonProfits": "We love working with non-profits in our core focus areas. If you’re a non-profit seeking forecasting and modeling capacity or are interested in collaborating on a grant, please reach out!",
595595
"reachOutToLearnMore": "Please reach out to learn more about working with Metaculus or with any questions, comments or concerns. We’ll get back to you soon.",
596596
"feelFreeToJustSayHello": "Feel free to just say hello, too - we love hearing from the forecasting community!",
597597
"contentBoosted": "Content boosted! value {score} activity. Total boost score for the week: {score_total}",
598598
"contentBuried": "Content buried! value {score} activity. Total boost score for the week: {score_total}",
599-
"downloadCSVError": "Error downloading CSV",
599+
"downloadQuestionDataError": "Error downloading Question data",
600600
"cpRevealTime": "CP Reveal Time",
601601
"openTime": "Open Time",
602602
"Question": "Question",

front_end/src/app/(main)/questions/actions.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,8 @@ export async function changePostSubscriptions(
351351
return response;
352352
}
353353

354-
export async function getPostCSVData(postId: number) {
355-
const blob = await PostsApi.getPostCSVData(postId);
354+
export async function getPostZipData(postId: number) {
355+
const blob = await PostsApi.getPostZipData(postId);
356356
const arrayBuffer = await blob.arrayBuffer();
357357
const base64String = Buffer.from(arrayBuffer).toString("base64");
358358

front_end/src/components/post_actions/post_dropdown_menu.tsx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import toast from "react-hot-toast";
88

99
import {
1010
changePostActivityBoost,
11-
getPostCSVData,
11+
getPostZipData,
1212
} from "@/app/(main)/questions/actions";
1313
import Button from "@/components/ui/button";
1414
import DropdownMenu, { MenuItemProps } from "@/components/ui/dropdown_menu";
@@ -50,22 +50,22 @@ export const PostDropdownMenu: FC<Props> = ({ post }) => {
5050
return `/questions/create/question?mode=create&post_id=${post.id}`;
5151
};
5252

53-
const handleDownloadCSV = async () => {
53+
const handleDownloadQuestionData = async () => {
5454
try {
55-
const base64 = await getPostCSVData(post.id);
55+
const base64 = await getPostZipData(post.id);
5656
const blob = base64ToBlob(base64);
57-
const filename = `${post.url_title.replaceAll(" ", "_")}.csv`;
57+
const filename = `${post.url_title.replaceAll(" ", "_")}.zip`;
5858
saveAs(blob, filename);
5959
} catch (error) {
60-
toast.error(t("downloadCSVError") + error);
60+
toast.error(t("downloadQuestionDataError") + error);
6161
}
6262
};
6363

6464
const menuItems: MenuItemProps[] = [
6565
{
66-
id: "downloadCSV",
67-
name: t("downloadCSV"),
68-
onClick: handleDownloadCSV,
66+
id: "downloadQuestionData",
67+
name: t("downloadQuestionData"),
68+
onClick: handleDownloadQuestionData,
6969
},
7070
];
7171
if (user?.is_superuser) {

front_end/src/services/posts.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ class PostsApi {
221221
return await get<{ id: number; post_slug: string }>("/posts/random/");
222222
}
223223

224-
static async getPostCSVData(postId: number): Promise<Blob> {
225-
return await get<Blob>(`/posts/${postId}/download-csv/`);
224+
static async getPostZipData(postId: number): Promise<Blob> {
225+
return await get<Blob>(`/posts/${postId}/download-data/`);
226226
}
227227
}
228228

front_end/src/utils/fetch.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ const handleResponse = async <T>(response: Response): Promise<T> => {
7474
// Check the content type to determine how to process the response
7575
const contentType = response.headers.get("content-type");
7676

77-
if (contentType && contentType.includes("text/csv")) {
78-
// If the response is a CSV, return it as a Blob
77+
if (contentType && contentType.includes("application/zip")) {
78+
// If the response is a ZIP, return it as a Blob
7979
return response.blob() as unknown as T;
8080
}
8181

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Generated by Django 5.1.4 on 2024-12-27 18:53
2+
3+
import django.db.models.deletion
4+
import django.utils.timezone
5+
from django.conf import settings
6+
from django.db import migrations, models
7+
8+
9+
class Migration(migrations.Migration):
10+
11+
dependencies = [
12+
("misc", "0002_initial"),
13+
("posts", "0012_alter_post_default_project"),
14+
("projects", "0010_remove_project_add_posts_to_main_feed_and_more"),
15+
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
16+
]
17+
18+
operations = [
19+
migrations.CreateModel(
20+
name="WhitelistUser",
21+
fields=[
22+
(
23+
"id",
24+
models.BigAutoField(
25+
auto_created=True,
26+
primary_key=True,
27+
serialize=False,
28+
verbose_name="ID",
29+
),
30+
),
31+
(
32+
"created_at",
33+
models.DateTimeField(
34+
default=django.utils.timezone.now, editable=False
35+
),
36+
),
37+
("edited_at", models.DateTimeField(editable=False, null=True)),
38+
(
39+
"post",
40+
models.ForeignKey(
41+
help_text="Optional. If provided, this allows the user to download user-level data for the post. If neither project nor post is set, the user is whitelisted for all data.",
42+
null=True,
43+
on_delete=django.db.models.deletion.CASCADE,
44+
related_name="whitelists",
45+
to="posts.post",
46+
),
47+
),
48+
(
49+
"project",
50+
models.ForeignKey(
51+
help_text="Optional. If provided, this allows the user to download user-level data for the project. If neither project nor post is set, the user is whitelisted for all data.",
52+
null=True,
53+
on_delete=django.db.models.deletion.CASCADE,
54+
related_name="whitelists",
55+
to="projects.project",
56+
),
57+
),
58+
(
59+
"user",
60+
models.ForeignKey(
61+
on_delete=django.db.models.deletion.CASCADE,
62+
related_name="whitelists",
63+
to=settings.AUTH_USER_MODEL,
64+
),
65+
),
66+
],
67+
options={
68+
"abstract": False,
69+
},
70+
),
71+
]

misc/models.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from pgvector.django import VectorField
33

44
from utils.models import TimeStampedModel
5+
from users.models import User
6+
from projects.models import Project
7+
from posts.models import Post
58

69

710
class ITNArticle(TimeStampedModel):
@@ -30,9 +33,33 @@ class Bulletin(TimeStampedModel):
3033

3134
class BulletinViewedBy(TimeStampedModel):
3235
bulletin = models.ForeignKey(Bulletin, on_delete=models.CASCADE)
33-
user = models.ForeignKey("users.User", on_delete=models.CASCADE)
36+
user = models.ForeignKey(User, on_delete=models.CASCADE)
3437

3538

3639
# TODO: index new posts
3740
# TODO: ensure we sync PostITNArticle new articles only
3841
# TODO: create a sync command + cron job
42+
43+
44+
class WhitelistUser(TimeStampedModel):
45+
"""Whitelist for users for permission to download user-level data"""
46+
47+
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="whitelists")
48+
project = models.ForeignKey(
49+
Project,
50+
null=True,
51+
on_delete=models.CASCADE,
52+
related_name="whitelists",
53+
help_text="Optional. If provided, this allows the user to download user-level "
54+
"data for the project. If neither project nor post is set, the user is "
55+
"whitelisted for all data.",
56+
)
57+
post = models.ForeignKey(
58+
Post,
59+
null=True,
60+
on_delete=models.CASCADE,
61+
related_name="whitelists",
62+
help_text="Optional. If provided, this allows the user to download user-level "
63+
"data for the post. If neither project nor post is set, the user is "
64+
"whitelisted for all data.",
65+
)

posts/admin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def export_selected_posts_data(self, request, queryset: QuerySet[Post]):
6464

6565
questions = Question.objects.filter(related_posts__post__in=queryset).distinct()
6666

67-
data = export_data_for_questions(questions)
67+
data = export_data_for_questions(questions, True, True, True)
6868
if data is None:
6969
self.message_user(request, "No questions selected.")
7070
return

posts/serializers.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
GroupOfQuestionsUpdateSerializer,
3030
)
3131
from questions.services import get_aggregated_forecasts_for_questions
32+
from questions.types import AggregationMethod
3233
from users.models import User
3334
from utils.dtypes import flatten
3435
from utils.serializers import SerializerKeyLookupMixin
@@ -594,3 +595,90 @@ class PostRelatedArticleSerializer(serializers.ModelSerializer):
594595
class Meta:
595596
model = ITNArticle
596597
fields = ("id", "title", "url", "favicon_url", "created_at", "media_label")
598+
599+
600+
class DownloadDataSerializer(serializers.Serializer):
601+
sub_question = serializers.IntegerField(required=False)
602+
aggregation_methods = serializers.CharField(required=False)
603+
user_ids = serializers.CharField(required=False, allow_null=True)
604+
include_comments = serializers.BooleanField(required=False, default=False)
605+
include_scores = serializers.BooleanField(required=False, default=False)
606+
include_bots = serializers.BooleanField(required=False, allow_null=True)
607+
minimize = serializers.BooleanField(required=False, default=True)
608+
609+
def validate_aggregation_methods(self, value):
610+
if value is None:
611+
return
612+
user: User = self.context["user"]
613+
if value == "all":
614+
aggregation_methods = [
615+
AggregationMethod.RECENCY_WEIGHTED,
616+
AggregationMethod.UNWEIGHTED,
617+
AggregationMethod.METACULUS_PREDICTION,
618+
]
619+
if user.is_staff:
620+
aggregation_methods.append(AggregationMethod.SINGLE_AGGREGATION)
621+
return aggregation_methods
622+
methods = value.split(",")
623+
invalid_methods = [
624+
method for method in methods if method not in AggregationMethod.values
625+
]
626+
if invalid_methods:
627+
raise serializers.ValidationError(
628+
f"Invalid aggregation method(s): {', '.join(invalid_methods)}"
629+
)
630+
if not user.is_staff:
631+
methods = [
632+
method
633+
for method in methods
634+
if method != AggregationMethod.SINGLE_AGGREGATION
635+
]
636+
return methods
637+
638+
def validate_user_ids(self, value):
639+
if not value:
640+
return value
641+
user_ids = value.split(",")
642+
if not all(user_id.isdigit() for user_id in user_ids):
643+
raise serializers.ValidationError(
644+
"Invalid user_ids. Must be a comma-separated list of integers."
645+
)
646+
if not self.context["can_view_private_data"]:
647+
raise serializers.ValidationError(
648+
"Current user cannot view user-specific data. "
649+
"Please remove user_ids parameter."
650+
)
651+
uids = [int(user_id) for user_id in user_ids]
652+
return uids
653+
654+
def validate(self, attrs):
655+
# Check if there are any unexpected fields
656+
allowed_fields = {
657+
"sub_question",
658+
"aggregation_methods",
659+
"user_ids",
660+
"include_comments",
661+
"include_scores",
662+
"include_bots",
663+
"minimize",
664+
}
665+
input_fields = set(self.initial_data.keys())
666+
unexpected_fields = input_fields - allowed_fields
667+
if unexpected_fields:
668+
raise ValidationError(f"Unexpected fields: {', '.join(unexpected_fields)}")
669+
670+
# Aggregation validation logic
671+
aggregation_methods = attrs.get("aggregation_methods")
672+
user_ids = attrs.get("user_ids")
673+
include_bots = attrs.get("include_bots")
674+
minimize = attrs.get("minimize", True)
675+
676+
if not aggregation_methods and (
677+
user_ids is not None or include_bots is not None or not minimize
678+
):
679+
raise serializers.ValidationError(
680+
"If user_ids, include_bots, or minimize is set, "
681+
"aggregation_methods must also be set."
682+
)
683+
684+
return attrs

0 commit comments

Comments
 (0)