diff --git a/docs/openapi.yml b/docs/openapi.yml index 9728f8f31a..4a74dc7ae2 100644 --- a/docs/openapi.yml +++ b/docs/openapi.yml @@ -534,6 +534,40 @@ components: items: type: string description: "List of options for multiple_choice questions" + example: + - "Democratic" + - "Republican" + - "Libertarian" + - "Green" + - "Other" + all_options_ever: + type: array + items: + type: string + description: "List of all options ever for multiple_choice questions" + example: + - "Democratic" + - "Republican" + - "Libertarian" + - "Green" + - "Blue" + - "Other" + options_history: + type: array + description: "List of [iso format time, options] pairs for multiple_choice questions" + items: + type: array + items: + oneOf: + - type: string + description: "ISO 8601 timestamp when the options became active" + - type: array + items: + type: string + description: "Options list active from this timestamp onward" + example: + - ["0001-01-01T00:00:00", ["a", "b", "c", "other"]] + - ["2026-10-22T16:00:00", ["a", "b", "c", "d", "other"]] status: type: string enum: [ upcoming, open, closed, resolved ] @@ -1346,6 +1380,7 @@ paths: actual_close_time: "2020-11-01T00:00:00Z" type: "numeric" options: null + options_history: null status: "resolved" resolution: "77289125.94957079" resolution_criteria: "Resolution Criteria Copy" @@ -1519,6 +1554,7 @@ paths: actual_close_time: "2015-12-15T03:34:00Z" type: "binary" options: null + options_history: null status: "resolved" possibilities: type: "binary" @@ -1588,6 +1624,16 @@ paths: - "Libertarian" - "Green" - "Other" + all_options_ever: + - "Democratic" + - "Republican" + - "Libertarian" + - "Green" + - "Blue" + - "Other" + options_history: + - ["0001-01-01T00:00:00", ["Democratic", "Republican", "Libertarian", "Other"]] + - ["2026-10-22T16:00:00", ["Democratic", "Republican", "Libertarian", "Green", "Other"]] status: "open" possibilities: { } resolution: null diff --git a/front_end/messages/cs.json b/front_end/messages/cs.json index 30591a4997..241da38cad 100644 --- a/front_end/messages/cs.json +++ b/front_end/messages/cs.json @@ -1823,5 +1823,12 @@ "tournamentsInfoTitle": "Jsme nepredikční trh. Můžete se účastnit zdarma a vyhrát peněžní ceny za přesnost.", "tournamentsInfoScoringLink": "Co jsou předpovídací skóre?", "tournamentsInfoPrizesLink": "Jak jsou rozdělovány ceny?", + "dismiss": "Zavřít", + "gracePeriodTooltip": "Pokud neaktualizujete své předpovědi před koncem této lhůty, vaše stávající předpovědi budou automaticky staženy.", + "newOptionsAddedPlural": "Tyto možnosti byly nedávno přidány, prosím upravte své předpovědi odpovídajícím způsobem.", + "newOptionsAddedSingular": "Nedávno byla přidána nová možnost, prosím upravte své předpovědi odpovídajícím způsobem.", + "showNewOptions": "Zobrazit nové možnosti", + "showNewOption": "Zobrazit novou možnost", + "timeRemaining": "Zbývající čas", "othersCount": "Ostatní ({count})" } diff --git a/front_end/messages/en.json b/front_end/messages/en.json index e20265c86a..f9a316c2bc 100644 --- a/front_end/messages/en.json +++ b/front_end/messages/en.json @@ -19,6 +19,13 @@ "withdraw": "Withdraw", "withdrawForecast": "Withdraw Forecast", "withdrawAll": "Withdraw All", + "dismiss": "Dismiss", + "gracePeriodTooltip": "If you don't update your forecasts before the grace period ends, your existing forecasts will be automatically withdrawn.", + "newOptionsAddedPlural": "These options were recently added, please adjust your forecast(s) accordingly.", + "newOptionsAddedSingular": "A new option was recently added, please adjust your forecasts accordingly.", + "showNewOptions": "Show New Options", + "showNewOption": "Show New Option", + "timeRemaining": "Time remaining", "saveChange": "Save Change", "reaffirm": "Reaffirm", "reaffirmAll": "Reaffirm All", diff --git a/front_end/messages/es.json b/front_end/messages/es.json index a6c2788afa..f99339f7f5 100644 --- a/front_end/messages/es.json +++ b/front_end/messages/es.json @@ -1823,5 +1823,12 @@ "tournamentsInfoTitle": "Nosotros no somos un mercado de predicciones. Puedes participar gratis y ganar premios en efectivo por ser preciso.", "tournamentsInfoScoringLink": "¿Qué son las puntuaciones de predicción?", "tournamentsInfoPrizesLink": "¿Cómo se distribuyen los premios?", + "dismiss": "Descartar", + "gracePeriodTooltip": "Si no actualiza sus pronósticos antes de que termine el período de gracia, sus pronósticos existentes se retirarán automáticamente.", + "newOptionsAddedPlural": "Estas opciones se añadieron recientemente, por favor ajuste su(s) pronóstico(s) en consecuencia.", + "newOptionsAddedSingular": "Se añadió una nueva opción recientemente, por favor ajuste sus pronósticos en consecuencia.", + "showNewOptions": "Mostrar nuevas opciones", + "showNewOption": "Mostrar nueva opción", + "timeRemaining": "Tiempo restante", "othersCount": "Otros ({count})" } diff --git a/front_end/messages/pt.json b/front_end/messages/pt.json index 99a3b22766..009248f733 100644 --- a/front_end/messages/pt.json +++ b/front_end/messages/pt.json @@ -1821,5 +1821,12 @@ "tournamentsInfoTitle": "Nós não somos um mercado de previsões. Você pode participar gratuitamente e ganhar prêmios em dinheiro por ser preciso.", "tournamentsInfoScoringLink": "O que são pontuações de previsão?", "tournamentsInfoPrizesLink": "Como os prêmios são distribuídos?", + "dismiss": "Dispensar", + "gracePeriodTooltip": "Se você não atualizar suas previsões antes do término do período de carência, suas previsões existentes serão retiradas automaticamente.", + "newOptionsAddedPlural": "Essas opções foram adicionadas recentemente, por favor ajuste suas previsões de acordo.", + "newOptionsAddedSingular": "Uma nova opção foi adicionada recentemente, por favor ajuste suas previsões de acordo.", + "showNewOptions": "Mostrar Novas Opções", + "showNewOption": "Mostrar Nova Opção", + "timeRemaining": "Tempo restante", "othersCount": "Outros ({count})" } diff --git a/front_end/messages/zh-TW.json b/front_end/messages/zh-TW.json index 8ba88fc0df..884d477d3e 100644 --- a/front_end/messages/zh-TW.json +++ b/front_end/messages/zh-TW.json @@ -1820,5 +1820,12 @@ "tournamentsInfoTitle": "我們 不是預測市場。您可以免費參加並因精確的預測贏取現金獎勵。", "tournamentsInfoScoringLink": "什麼是預測得分?", "tournamentsInfoPrizesLink": "獎品如何分配?", + "dismiss": "關閉", + "gracePeriodTooltip": "如果您在寬限期間結束之前未更新您的預測,您的現有預測將自動撤回。", + "newOptionsAddedPlural": "這些選項最近新增,請相應調整您的預測。", + "newOptionsAddedSingular": "一個新選項最近新增,請相應調整您的預測。", + "showNewOptions": "顯示新選項", + "showNewOption": "顯示新選項", + "timeRemaining": "剩餘時間", "withdrawAfterPercentSetting2": "問題總生命周期後撤回" } diff --git a/front_end/messages/zh.json b/front_end/messages/zh.json index 8b579b3602..290306bd16 100644 --- a/front_end/messages/zh.json +++ b/front_end/messages/zh.json @@ -1825,5 +1825,12 @@ "tournamentsInfoTitle": "我们不是一个预测市场。您可以免费参与,并因精准的预测赢得现金奖品。", "tournamentsInfoScoringLink": "什么是预测分数?", "tournamentsInfoPrizesLink": "奖品如何分配?", + "dismiss": "忽略", + "gracePeriodTooltip": "如果您在宽限期结束前没有更新您的预测,现有的预测将自动撤回。", + "newOptionsAddedPlural": "这些选项是最近添加的,请相应调整您的预测。", + "newOptionsAddedSingular": "最近添加了一个新选项,请相应调整您的预测。", + "showNewOptions": "显示新选项", + "showNewOption": "显示新选项", + "timeRemaining": "剩余时间", "othersCount": "其他({count})" } diff --git a/front_end/src/app/(main)/aggregation-explorer/components/explorer.tsx b/front_end/src/app/(main)/aggregation-explorer/components/explorer.tsx index 5d20a5c6f7..ddd8b23b38 100644 --- a/front_end/src/app/(main)/aggregation-explorer/components/explorer.tsx +++ b/front_end/src/app/(main)/aggregation-explorer/components/explorer.tsx @@ -23,7 +23,10 @@ import { SearchParams } from "@/types/navigation"; import { Post, PostWithForecasts } from "@/types/post"; import { QuestionType, QuestionWithForecasts } from "@/types/question"; import { logError } from "@/utils/core/errors"; -import { parseQuestionId } from "@/utils/questions/helpers"; +import { + getAllOptionsHistory, + parseQuestionId, +} from "@/utils/questions/helpers"; import { AggregationWrapper } from "./aggregation_wrapper"; import { AggregationExtraMethod } from "../types"; @@ -417,8 +420,9 @@ function parseSubQuestions( }, ]; } else if (data.question?.type === QuestionType.MultipleChoice) { + const allOptions = getAllOptionsHistory(data.question); return ( - data.question.options?.map((option) => ({ + allOptions?.map((option) => ({ value: option, label: option, })) || [] diff --git a/front_end/src/app/(main)/questions/[id]/components/question_view/forecaster_question_view/question_header/question_header_cp_status.tsx b/front_end/src/app/(main)/questions/[id]/components/question_view/forecaster_question_view/question_header/question_header_cp_status.tsx index b46f804ed4..1d630c2bd2 100644 --- a/front_end/src/app/(main)/questions/[id]/components/question_view/forecaster_question_view/question_header/question_header_cp_status.tsx +++ b/front_end/src/app/(main)/questions/[id]/components/question_view/forecaster_question_view/question_header/question_header_cp_status.tsx @@ -32,14 +32,16 @@ const QuestionHeaderCPStatus: FC = ({ const t = useTranslations(); const { hideCP } = useHideCP(); const forecastAvailability = getQuestionForecastAvailability(question); - const continuousAreaChartData = getContinuousAreaChartData({ - question, - isClosed: question.status === QuestionStatus.CLOSED, - }); const isContinuous = question.type === QuestionType.Numeric || question.type === QuestionType.Discrete || question.type === QuestionType.Date; + const continuousAreaChartData = !isContinuous + ? null + : getContinuousAreaChartData({ + question, + isClosed: question.status === QuestionStatus.CLOSED, + }); if (question.status === QuestionStatus.RESOLVED && question.resolution) { // Resolved/Annulled/Ambiguous diff --git a/front_end/src/components/charts/minified_continuous_area_chart.tsx b/front_end/src/components/charts/minified_continuous_area_chart.tsx index 4fce59f5e1..111d266603 100644 --- a/front_end/src/components/charts/minified_continuous_area_chart.tsx +++ b/front_end/src/components/charts/minified_continuous_area_chart.tsx @@ -56,7 +56,7 @@ const HORIZONTAL_PADDING = 10; type Props = { question: Question | GraphingQuestionProps; - data: ContinuousAreaGraphInput; + data: ContinuousAreaGraphInput | null; height?: number; width?: number; extraTheme?: VictoryThemeDefinition; @@ -81,6 +81,9 @@ const MinifiedContinuousAreaChart: FC = ({ forceTickCount, variant = "feed", }) => { + if (data === null) { + throw new Error("Data for MinifiedContinuousAreaChart is null"); + } const { ref: chartContainerRef, width: containerWidth } = useContainerSize(); const chartWidth = width || containerWidth; diff --git a/front_end/src/components/forecast_maker/forecast_choice_option.tsx b/front_end/src/components/forecast_maker/forecast_choice_option.tsx index 46236c978a..f96a39fe83 100644 --- a/front_end/src/components/forecast_maker/forecast_choice_option.tsx +++ b/front_end/src/components/forecast_maker/forecast_choice_option.tsx @@ -28,6 +28,14 @@ import { getForecastPctDisplayValue } from "@/utils/formatters/prediction"; import ForecastTextInput from "./forecast_text_input"; import Tooltip from "../ui/tooltip"; +// ============================================ +// ANIMATION & OPACITY SETTINGS - ADJUST HERE +// ============================================ +const GRADIENT_OPACITY_NORMAL = "1A"; // Normal state: ~10% (hex) +const GRADIENT_OPACITY_HOVER = "2D"; // Hover state: ~18% (hex) +const BORDER_WIDTH = "4px"; // Border width when animating +export const ANIMATION_DURATION_MS = 1500; // Total animation duration in milliseconds + type OptionResolution = { resolution: Resolution | null; type: "question" | "group_question"; @@ -52,6 +60,11 @@ type Props = { onOptionClick?: (id: T) => void; withdrawn?: boolean; withdrawnEndTimeSec?: number | null; + isNewOption?: boolean; + showHighlight?: boolean; + isAnimating?: boolean; + onInteraction?: () => void; + rowRef?: React.RefObject; }; const ForecastChoiceOption = ({ @@ -73,9 +86,20 @@ const ForecastChoiceOption = ({ onOptionClick, withdrawn = false, withdrawnEndTimeSec = null, + isNewOption = false, + showHighlight = false, + isAnimating = false, + onInteraction, + rowRef, }: Props) => { const t = useTranslations(); const locale = useLocale(); + const [isHovered, setIsHovered] = useState(false); + const [mounted, setMounted] = useState(false); + + useEffect(() => { + setMounted(true); + }, []); const inputDisplayValue = withdrawn && !isDirty @@ -124,8 +148,9 @@ const ForecastChoiceOption = ({ const handleSliderForecastChange = useCallback( (value: number) => { onChange(id, value); + onInteraction?.(); }, - [id, onChange] + [id, onChange, onInteraction] ); const handleInputChange = useCallback((value: string) => { setInputValue(value); @@ -181,16 +206,35 @@ const ForecastChoiceOption = ({ ); + const gradientColor = getThemeColor(choiceColor); + return ( <> onOptionClick?.(id)} + onMouseEnter={() => setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + style={{ + ...(mounted && + showHighlight && { + backgroundImage: `linear-gradient(to right, ${gradientColor}${isHovered ? GRADIENT_OPACITY_HOVER : GRADIENT_OPACITY_NORMAL} 0%, transparent 100%)`, + }), + ...(mounted && + isNewOption && { + outline: isAnimating + ? `${BORDER_WIDTH} solid ${gradientColor}` + : "0px solid transparent", + outlineOffset: "-4px", + transition: `outline ${ANIMATION_DURATION_MS * 0.2}ms ease-in-out`, + }), + }} >
@@ -228,6 +272,7 @@ const ForecastChoiceOption = ({ onFocus={() => { setIsInputFocused(true); onChange(id, defaultSliderValue); + onInteraction?.(); }} onBlur={() => setIsInputFocused(false)} disabled={disabled} @@ -249,7 +294,10 @@ const ForecastChoiceOption = ({ minValue={inputMin} maxValue={inputMax} value={inputValue} - onFocus={() => setIsInputFocused(true)} + onFocus={() => { + setIsInputFocused(true); + onInteraction?.(); + }} onBlur={() => setIsInputFocused(false)} disabled={disabled} /> @@ -263,10 +311,22 @@ const ForecastChoiceOption = ({ onOptionClick?.(id)} + onMouseEnter={() => setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + style={ + mounted && showHighlight + ? { + backgroundImage: `linear-gradient(to right, ${gradientColor}${isHovered ? GRADIENT_OPACITY_HOVER : GRADIENT_OPACITY_NORMAL} 0%, transparent 100%)`, + } + : undefined + } > ; + mounted: boolean; + getThemeColor: (color: ThemeColor) => string; + gracePeriodEnd: Date | null; + onShowNewOptions: () => void; + onDismiss: () => void; +}; + +/** + * Hook to display remaining time until grace period ends + * Updates every second for a live countdown + */ +const useGracePeriodCountdown = (gracePeriodEnd: Date | null) => { + const [timeRemaining, setTimeRemaining] = useState(""); + + useEffect(() => { + if (!gracePeriodEnd) { + setTimeRemaining(""); + return; + } + + const updateTime = () => { + const now = new Date(); + const diff = gracePeriodEnd.getTime() - now.getTime(); + + if (diff <= 0) { + setTimeRemaining("expired"); + return; + } + + const days = Math.floor(diff / (1000 * 60 * 60 * 24)); + const hours = Math.floor( + (diff % (1000 * 60 * 60 * 24)) / (1000 * 60 * 60) + ); + const minutes = Math.floor((diff % (1000 * 60 * 60)) / (1000 * 60)); + const seconds = Math.floor((diff % (1000 * 60)) / 1000); + + const pluralize = (count: number, singular: string) => + count === 1 ? singular : `${singular}s`; + + if (days > 0) { + setTimeRemaining( + `${days} ${pluralize(days, "day")}, ${hours} ${pluralize(hours, "hour")}` + ); + } else if (hours > 0) { + setTimeRemaining( + `${hours} ${pluralize(hours, "hour")}, ${minutes} ${pluralize(minutes, "minute")}` + ); + } else if (minutes > 0) { + setTimeRemaining( + `${minutes} ${pluralize(minutes, "minute")}, ${seconds} ${pluralize(seconds, "second")}` + ); + } else { + setTimeRemaining(`${seconds} ${pluralize(seconds, "second")}`); + } + }; + + updateTime(); + const interval = setInterval(updateTime, 1000); + + return () => clearInterval(interval); + }, [gracePeriodEnd]); + + return timeRemaining; +}; + +const NewOptionCallout: FC = ({ + newOptions, + mounted, + getThemeColor, + gracePeriodEnd, + onShowNewOptions, + onDismiss, +}) => { + const t = useTranslations(); + const isPlural = newOptions.length > 1; + const timeRemaining = useGracePeriodCountdown(gracePeriodEnd); + + return ( +
+
+

+ {isPlural ? t("newOptionsAddedPlural") : t("newOptionsAddedSingular")} +

+ {timeRemaining && timeRemaining !== "expired" && ( + +
+ + {t("timeRemaining")}: + + + {timeRemaining} + +
+
+ )} +
+ {isPlural && newOptions.length > 0 && mounted && ( +
+ {newOptions.map((option) => ( +
+
+ + {option.name} + +
+ ))} +
+ )} +
+ + +
+
+ ); +}; + type Props = { post: PostWithForecasts; question: QuestionWithMultipleChoiceForecasts; @@ -72,6 +219,14 @@ const ForecastMakerMultipleChoice: FC = ({ const t = useTranslations(); const { user } = useAuth(); const { hideCP } = useHideCP(); + const { getThemeColor } = useAppTheme(); + const [mounted, setMounted] = useState(false); + + useEffect(() => { + setMounted(true); + }, []); + + const allOptions = getAllOptionsHistory(question); const activeUserForecast = question.my_forecasts?.latest && @@ -112,7 +267,7 @@ const ForecastMakerMultipleChoice: FC = ({ }, [question, user?.prediction_expiration_percent]); // Set default expiration if not already set - React.useEffect(() => { + useEffect(() => { if (!modalSavedState.forecastExpiration) { setModalSavedState((prev) => ({ ...prev, @@ -127,6 +282,11 @@ const ForecastMakerMultipleChoice: FC = ({ const [isDirty, setIsDirty] = useState(false); const [isWithdrawModalOpen, setIsWithdrawModalOpen] = useState(false); + const [dismissedOverlay, setDismissedOverlay] = useState(false); + const [interactedOptions, setInteractedOptions] = useState>( + new Set() + ); + const [isAnimatingHighlight, setIsAnimatingHighlight] = useState(false); const [choicesForecasts, setChoicesForecasts] = useState( generateChoiceOptions( question, @@ -144,19 +304,95 @@ const ForecastMakerMultipleChoice: FC = ({ [choicesForecasts] ); const forecastsSum = useMemo( - () => (forecastHasValues ? sumForecasts(choicesForecasts) : null), - [choicesForecasts, forecastHasValues] + () => + forecastHasValues + ? sumForecasts( + choicesForecasts.filter((choice) => + question.options.includes(choice.name) + ) + ) + : null, + [question.options, choicesForecasts, forecastHasValues] ); const remainingSum = forecastsSum ? 100 - forecastsSum : null; const isForecastValid = forecastHasValues && forecastsSum === 100; const [submitError, setSubmitError] = useState(); + const showUserMustForecast = + !!activeUserForecast && + activeUserForecast.forecast_values.filter((value) => value !== null) + .length < question.options.length; + + const getNewOptions = useCallback(() => { + if (!activeUserForecast) return []; + + return choicesForecasts + .filter((choice, index) => { + const isCurrentOption = question.options.includes(choice.name); + const hasForecast = activeUserForecast.forecast_values[index] !== null; + return isCurrentOption && !hasForecast; + }) + .map((c) => ({ name: c.name, color: c.color })); + }, [activeUserForecast, choicesForecasts, question.options]); + + const newOptions = getNewOptions(); + const showOverlay = + showUserMustForecast && !dismissedOverlay && newOptions.length > 0; + + // Calculate grace period end time + const gracePeriodEnd = useMemo(() => { + try { + if (!question.options_history || question.options_history.length === 0) { + return null; + } + const history = question.options_history; + const lastEntry = history[history.length - 1]; + + if (!lastEntry || typeof lastEntry[0] === "undefined") { + return null; + } + + // Following coworker's implementation: new Date(history[history.length - 1][0]) + const gracePeriodEnd = new Date(lastEntry[0]); + + // Validate the date is valid + if (isNaN(gracePeriodEnd.getTime())) { + console.warn("Invalid grace period date:", lastEntry[0]); + return null; + } + + return gracePeriodEnd; + } catch (error) { + console.error("Error calculating grace period:", error); + return null; + } + }, [question.options_history]); + + const firstNewOptionRef = useRef(null); + + const scrollToNewOptions = () => { + if (firstNewOptionRef.current) { + // Trigger animation immediately + setIsAnimatingHighlight(true); + + firstNewOptionRef.current.scrollIntoView({ + behavior: "smooth", + block: "center", + }); + + // Reset animation after duration + setTimeout(() => { + setIsAnimatingHighlight(false); + }, ANIMATION_DURATION_MS); + } + }; + const resetForecasts = useCallback(() => { setIsDirty(false); setChoicesForecasts((prev) => - question.options.map((_, index) => { - // okay to do no-non-null-assertion, as choicesForecasts is mapped based on question.options + allOptions.map((_, index) => { + // okay to do no-non-null-assertion, as choicesForecasts is mapped based on allOptions // so there won't be a case where arrays are not of the same length // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const choiceOption = prev[index]!; @@ -171,7 +407,7 @@ const ForecastMakerMultipleChoice: FC = ({ }; }) ); - }, [question.options, question.my_forecasts?.latest?.forecast_values]); + }, [allOptions, question.my_forecasts?.latest?.forecast_values]); const handleForecastChange = useCallback( (choice: string, value: number) => { @@ -184,7 +420,7 @@ const ForecastMakerMultipleChoice: FC = ({ const isInitialChange = prev.some((el) => el.forecast === null); - if (isInitialChange) { + if (isInitialChange && prevChoice.forecast === null) { // User is predicting for the first time. Show default non-null values // for remaining options after first interaction with the inputs. return { ...prevChoice, forecast: equalizedForecast }; @@ -208,6 +444,9 @@ const ForecastMakerMultipleChoice: FC = ({ if (isNil(choice.forecast) || isNil(forecastsSum)) { return null; } + if (!question.options.includes(choice.name)) { + return 0.0; + } const value = round( Math.max(round((100 * choice.forecast) / forecastsSum, 1), 0.1), @@ -249,6 +488,9 @@ const ForecastMakerMultipleChoice: FC = ({ const forecastValue: Record = {}; choicesForecasts.forEach((el) => { + if (!question.options.includes(el.name)) { + return; // only submit forecasts for current options + } const forecast = el.forecast; if (!isNil(forecast)) { forecastValue[el.name] = round( @@ -329,6 +571,16 @@ const ForecastMakerMultipleChoice: FC = ({ isSubmissionDisabled={!isForecastValid} onSubmit={submit} /> + {showOverlay && ( + setDismissedOverlay(true)} + /> + )} @@ -354,28 +606,50 @@ const ForecastMakerMultipleChoice: FC = ({ - {choicesForecasts.map((choice) => ( - - ))} + {choicesForecasts.map((choice) => { + if (question.options.includes(choice.name)) { + const isFirstNewOption = + newOptions.length > 0 && choice.name === newOptions[0]?.name; + const isNewOption = newOptions.some( + (opt) => opt.name === choice.name + ); + return ( + { + isNewOption + ? setInteractedOptions((prev) => + new Set(prev).add(choice.name) + ) + : undefined; + }} + rowRef={isFirstNewOption ? firstNewOptionRef : undefined} + /> + ); + } + })}
{predictionMessage && ( @@ -485,8 +759,9 @@ function generateChoiceOptions( userLastForecast: MultipleChoiceUserForecast | undefined ): ChoiceOption[] { const latest = aggregate.latest; + const allOptions = getAllOptionsHistory(question); - const choiceItems = question.options.map((option, index) => { + const choiceItems = allOptions.map((option, index) => { const communityForecastValue = latest?.forecast_values[index]; const userForecastValue = userLastForecast?.forecast_values[index]; @@ -502,8 +777,8 @@ function generateChoiceOptions( : null, }; }); - const resolutionIndex = question.options.findIndex( - (_, index) => question.options[index] === question.resolution + const resolutionIndex = allOptions.findIndex( + (_, index) => allOptions[index] === question.resolution ); if (resolutionIndex !== -1) { const [resolutionItem] = choiceItems.splice(resolutionIndex, 1); diff --git a/front_end/src/types/question.ts b/front_end/src/types/question.ts index 8c6c2133d0..430d5127f3 100644 --- a/front_end/src/types/question.ts +++ b/front_end/src/types/question.ts @@ -280,6 +280,7 @@ export type Question = { type: QuestionType; // Multiple-choice only options?: string[]; + options_history?: [number, string[]][]; group_variable?: string; group_rank?: number; // Continuous only diff --git a/front_end/src/utils/questions/choices.ts b/front_end/src/utils/questions/choices.ts index d1c818a442..985c9501f7 100644 --- a/front_end/src/utils/questions/choices.ts +++ b/front_end/src/utils/questions/choices.ts @@ -14,6 +14,7 @@ import { formatResolution, } from "@/utils/formatters/resolution"; import { sortGroupPredictionOptions } from "@/utils/questions/groupOrdering"; +import { getAllOptionsHistory } from "@/utils/questions/helpers"; import { isUnsuccessfullyResolved } from "@/utils/questions/resolution"; export function generateChoiceItemsFromMultipleChoiceForecast( @@ -32,7 +33,8 @@ export function generateChoiceItemsFromMultipleChoiceForecast( const latest = question.aggregations[question.default_aggregation_method].latest; - const choiceOrdering: number[] = question.options?.map((_, i) => i) ?? []; + const allOptions = getAllOptionsHistory(question); + const choiceOrdering: number[] = allOptions?.map((_, i) => i) ?? []; if (!preserveOrder) { choiceOrdering.sort((a, b) => { const aCenter = latest?.forecast_values[a] ?? 0; @@ -41,7 +43,7 @@ export function generateChoiceItemsFromMultipleChoiceForecast( }); } - const labels = question.options ? question.options : []; + const labels = allOptions ? allOptions : []; const aggregationHistory = question.aggregations[question.default_aggregation_method].history; const userHistory = question.my_forecasts?.history; @@ -139,7 +141,7 @@ export function generateChoiceItemsFromMultipleChoiceForecast( const orderedChoiceItems = choiceOrdering.map((order) => choiceItems[order]); // move resolved choice to the front const resolutionIndex = choiceOrdering.findIndex( - (order) => question.options?.[order] === question.resolution + (order) => allOptions?.[order] === question.resolution ); if (resolutionIndex !== -1) { const [resolutionItem] = orderedChoiceItems.splice(resolutionIndex, 1); diff --git a/front_end/src/utils/questions/helpers.ts b/front_end/src/utils/questions/helpers.ts index f5f8dbe96b..e60d2b1da9 100644 --- a/front_end/src/utils/questions/helpers.ts +++ b/front_end/src/utils/questions/helpers.ts @@ -199,3 +199,19 @@ export function inferEffectiveQuestionTypeFromPost( return null; } + +export function getAllOptionsHistory(question: Question): string[] { + const allOptions: string[] = []; + (question.options_history ?? []).map((entry) => { + entry[1].slice(0, -1).map((option) => { + if (!allOptions.includes(option)) { + allOptions.push(option); + } + }); + }); + const other = (question.options ?? []).at(-1); + if (other) { + allOptions.push(other); + } + return allOptions; +} diff --git a/misc/views.py b/misc/views.py index 611da007bc..519036e88d 100644 --- a/misc/views.py +++ b/misc/views.py @@ -129,7 +129,9 @@ def get_site_stats(request): now_year = datetime.now().year public_questions = Question.objects.filter_public() stats = { - "predictions": Forecast.objects.filter(question__in=public_questions).count(), + "predictions": Forecast.objects.filter(question__in=public_questions) + .exclude(source=Forecast.SourceChoices.AUTOMATIC) + .count(), "questions": public_questions.count(), "resolved_questions": public_questions.filter(actual_resolve_time__isnull=False) .exclude(resolution__in=UnsuccessfulResolutionType) diff --git a/notifications/templates/emails/multiple_choice_option_addition.html b/notifications/templates/emails/multiple_choice_option_addition.html new file mode 100644 index 0000000000..d3ca2e1493 --- /dev/null +++ b/notifications/templates/emails/multiple_choice_option_addition.html @@ -0,0 +1,364 @@ + +{% load i18n %} {% load static %} {% load urls %} {% load utils %} + + + + + + + + + + + + + + + + + + + + +
+ +
+ + + + + + +
+ +
+ + + + + + + + + +
+ + + + + + +
+ Metaculus Logo +
+
+
+ +
+
+ +
+ + + + + + +
+ +
+ + + + + + +
+
Hello {{recipient.username}},
+
+
+ +
+
+ +
+ + + + + + +
+ +
+ + + + + + +
+
+
{% blocktrans %} New options were added to this multiple choice question. {% endblocktrans %}
+ {% if params.added_options %}
{% blocktrans with count=params.added_options|length %} Added option{{ count|pluralize }}: {% endblocktrans %}
+
    {% for option in params.added_options %}
  • {{ option }}
  • {% endfor %}
{% endif %}
{% blocktrans with grace_period_end=params.grace_period_end %} Please update your forecast by {{ grace_period_end }}. If not, it will be automatically withdrawn at that time. {% endblocktrans %}
+
{% blocktrans %} Learn more in our {% endblocktrans %} FAQ.
+
+
+
+ +
+
+ +
+ + + + + + +
+ +
+ + + + + + +
+ + + + + + +
+ Update your forecast +
+
+
+ +
+
+ +
+ + + + + + +
+ +
+ + + + + + +
+
{% blocktrans %} Related questions you might find interesting {% endblocktrans %}
+
+
+ +
+
+ + + + + + + + + +
+ + + + diff --git a/notifications/templates/emails/multiple_choice_option_addition.mjml b/notifications/templates/emails/multiple_choice_option_addition.mjml new file mode 100644 index 0000000000..0ff6b53ad0 --- /dev/null +++ b/notifications/templates/emails/multiple_choice_option_addition.mjml @@ -0,0 +1,74 @@ + + + + + + + + + + Hello {{recipient.username}}, + + + + + + + +
+ {% blocktrans %} + New options were added to this multiple choice question. + {% endblocktrans %} +
+ + + {% if params.added_options %} +
+ {% blocktrans with count=params.added_options|length %} + Added option{{ count|pluralize }}: + {% endblocktrans %} +
+
    + {% for option in params.added_options %} +
  • {{ option }}
  • + {% endfor %} +
+ {% endif %} + +
+ {% blocktrans with grace_period_end=params.grace_period_end %} + Please update your forecast by {{ grace_period_end }}. If not, + it will be automatically withdrawn at that time. + {% endblocktrans %} +
+
+ {% blocktrans %} + Learn more in our + {% endblocktrans %} + FAQ. +
+
+
+
+ + + + + Update your forecast + + + + + + +
+
diff --git a/notifications/templates/emails/multiple_choice_option_deletion.html b/notifications/templates/emails/multiple_choice_option_deletion.html new file mode 100644 index 0000000000..1949f457e7 --- /dev/null +++ b/notifications/templates/emails/multiple_choice_option_deletion.html @@ -0,0 +1,363 @@ + +{% load i18n %} {% load static %} {% load urls %} {% load utils %} + + + + + + + + + + + + + + + + + + + + +
+ +
+ + + + + + +
+ +
+ + + + + + + + + +
+ + + + + + +
+ Metaculus Logo +
+
+
+ +
+
+ +
+ + + + + + +
+ +
+ + + + + + +
+
Hello {{recipient.username}},
+
+
+ +
+
+ +
+ + + + + + +
+ +
+ + + + + + +
+
+
{% blocktrans %} Options were removed from this multiple choice question. {% endblocktrans %}
+ {% if params.removed_options %}
{% blocktrans with count=params.removed_options|length %} Removed option{{ count|pluralize }}: {% endblocktrans %}
+
    {% for option in params.removed_options %}
  • {{ option }}
  • {% endfor %}
{% endif %}
{% blocktrans with timestep=params.timestep %} This change took effect at {{ timestep }}. Any probability on removed options was moved to the catch-all option. {% endblocktrans %}
+
+
+
+ +
+
+ +
+ + + + + + +
+ +
+ + + + + + +
+ + + + + + +
+ Review the question +
+
+
+ +
+
+ +
+ + + + + + +
+ +
+ + + + + + +
+
{% blocktrans %} Related questions you might find interesting {% endblocktrans %}
+
+
+ +
+
+ + + + + + + + + +
+ + + + diff --git a/notifications/templates/emails/multiple_choice_option_deletion.mjml b/notifications/templates/emails/multiple_choice_option_deletion.mjml new file mode 100644 index 0000000000..de4dc3486c --- /dev/null +++ b/notifications/templates/emails/multiple_choice_option_deletion.mjml @@ -0,0 +1,68 @@ + + + + + + + + + + Hello {{recipient.username}}, + + + + + + + +
+ {% blocktrans %} + Options were removed from this multiple choice question. + {% endblocktrans %} +
+ + + {% if params.removed_options %} +
+ {% blocktrans with count=params.removed_options|length %} + Removed option{{ count|pluralize }}: + {% endblocktrans %} +
+
    + {% for option in params.removed_options %} +
  • {{ option }}
  • + {% endfor %} +
+ {% endif %} + +
+ {% blocktrans with timestep=params.timestep %} + This change took effect at {{ timestep }}. Any probability on removed options + was moved to the catch-all option. + {% endblocktrans %} +
+
+
+
+ + + + + Review the question + + + + + + +
+
diff --git a/posts/models.py b/posts/models.py index 275f08a701..757524403f 100644 --- a/posts/models.py +++ b/posts/models.py @@ -813,7 +813,11 @@ def update_forecasts_count(self): Update forecasts count cache """ - self.forecasts_count = self.forecasts.filter_within_question_period().count() + self.forecasts_count = ( + self.forecasts.filter_within_question_period() + .exclude(source=Forecast.SourceChoices.AUTOMATIC) + .count() + ) self.save(update_fields=["forecasts_count"]) def update_forecasters_count(self): diff --git a/questions/admin.py b/questions/admin.py index 66e662f750..afd8183733 100644 --- a/questions/admin.py +++ b/questions/admin.py @@ -1,10 +1,17 @@ from admin_auto_filters.filters import AutocompleteFilterFactory -from django.contrib import admin +from datetime import datetime, timedelta + +from django import forms +from django.contrib import admin, messages +from django.core.exceptions import PermissionDenied from django.db.models import QuerySet -from django.http import HttpResponse -from django.urls import reverse +from django.http import Http404, HttpResponse, HttpResponseRedirect +from django.template.response import TemplateResponse +from django.urls import path, reverse +from django.utils import timezone from django.utils.html import format_html from django_better_admin_arrayfield.admin.mixins import DynamicArrayMixin +from rest_framework.exceptions import ValidationError as DRFValidationError from posts.models import Post from posts.tasks import run_post_generate_history_snapshot @@ -18,10 +25,402 @@ ) from questions.services.forecasts import build_question_forecasts from questions.types import AggregationMethod +from questions.services.multiple_choice_handlers import ( + MultipleChoiceOptionsUpdateSerializer, + get_all_options_from_history, + multiple_choice_add_options, + multiple_choice_change_grace_period_end, + multiple_choice_delete_options, + multiple_choice_rename_option, + multiple_choice_reorder_options, +) from utils.csv_utils import export_all_data_for_questions from utils.models import CustomTranslationAdmin +def get_latest_options_history_datetime(options_history): + if not options_history: + return None + raw_timestamp = options_history[-1][0] + try: + if isinstance(raw_timestamp, datetime): + parsed_timestamp = raw_timestamp + elif isinstance(raw_timestamp, str): + parsed_timestamp = datetime.fromisoformat(raw_timestamp) + else: + return None + except ValueError: + return None + if timezone.is_naive(parsed_timestamp): + parsed_timestamp = timezone.make_aware(parsed_timestamp) + return parsed_timestamp + + +def has_active_grace_period(options_history, reference_time=None): + reference_time = reference_time or timezone.now() + latest_timestamp = get_latest_options_history_datetime(options_history) + return bool(latest_timestamp and latest_timestamp > reference_time) + + +class MultipleChoiceOptionsAdminForm(forms.Form): + ACTION_RENAME = "rename_options" + ACTION_DELETE = "delete_options" + ACTION_ADD = "add_options" + ACTION_CHANGE_GRACE = "change_grace_period_end" # not ready yet + ACTION_REORDER = "reorder_options" + ACTION_CHOICES = ( + (ACTION_RENAME, "Rename options"), + (ACTION_DELETE, "Delete options"), + (ACTION_ADD, "Add options"), + # (ACTION_CHANGE_GRACE, "Change grace period end"), + (ACTION_REORDER, "Reorder options"), + ) + + action = forms.ChoiceField(choices=ACTION_CHOICES, required=True) + old_option = forms.ChoiceField(required=False) + new_option = forms.CharField( + required=False, label="New option text", strip=True, max_length=200 + ) + options_to_delete = forms.MultipleChoiceField( + required=False, widget=forms.CheckboxSelectMultiple + ) + new_options = forms.CharField( + required=False, + help_text="Comma-separated options to add before the catch-all option.", + ) + grace_period_end = forms.DateTimeField( + required=False, + help_text=( + "Default value is 2 weeks from now. " + "Required when adding options; must be in the future. " + "Format: YYYY-MM-DD or YYYY-MM-DD HH:MM (time optional)." + ), + input_formats=["%Y-%m-%dT%H:%M", "%Y-%m-%d %H:%M", "%Y-%m-%d"], + ) + delete_comment = forms.CharField( + required=False, + label="Delete options comment", + widget=forms.Textarea(attrs={"rows": 3}), + help_text="Placeholders will auto-fill; edit as needed." + " {removed_options} becomes ['a', 'b'], {timestep} is the time of " + "deletion in isoformat.", + ) + add_comment = forms.CharField( + required=False, + label="Add options comment", + widget=forms.Textarea(attrs={"rows": 4}), + help_text="Placeholders will auto-fill; edit as needed." + " {removed_options} becomes ['a', 'b'], {timestep} is the time of " + "deletion in isoformat.", + ) + + def __init__(self, question: Question, *args, **kwargs): + self.question = question + super().__init__(*args, **kwargs) + + options_history = question.options_history or [] + self.options_grace_period_end = get_latest_options_history_datetime( + options_history + ) + default_delete_comment = ( + "Options {removed_options} were removed at {timestep}. " + "Forecasts were adjusted to keep remaining probability on the catch-all." + ) + default_add_comment = ( + "Options {added_options} were added at {timestep}. " + "Please update forecasts before {grace_period_end}; " + "forecasts that are not updated will auto-withdraw then." + ) + + active_grace = has_active_grace_period(options_history) + action_choices = list(self.ACTION_CHOICES) + if active_grace: + action_choices = [ + choice + for choice in action_choices + if choice[0] in (self.ACTION_RENAME, self.ACTION_CHANGE_GRACE) + ] + else: + action_choices = [ + choice + for choice in action_choices + if choice[0] != self.ACTION_CHANGE_GRACE + ] + if len(options_history) > 1: + action_choices = [ + choice for choice in action_choices if choice[0] != self.ACTION_REORDER + ] + action = forms.ChoiceField( + choices=[("", "Select action")] + action_choices, + required=True, + initial="", + ) + self.fields["action"] = action + all_options = ( + get_all_options_from_history(options_history) if options_history else [] + ) + self.fields["old_option"].choices = [(opt, opt) for opt in all_options] + + current_options = question.options or [] + self.fields["options_to_delete"].choices = [ + (opt, opt) for opt in current_options + ] + self.reorder_field_names: list[tuple[str, str]] = [] + for index, option in enumerate(current_options): + field_name = f"reorder_position_{index}" + self.reorder_field_names.append((option, field_name)) + self.fields[field_name] = forms.IntegerField( + required=False, + min_value=1, + label=f"Order for '{option}'", + help_text="Use integers; options will be ordered ascending.", + ) + if current_options: + self.fields["options_to_delete"].widget.attrs["data-catch-all"] = ( + current_options[-1] + ) + self.fields["options_to_delete"].help_text = ( + "Warning: do not remove all options. The question should have at least " + "2 options: the last option you can't delete, and one other." + ) + grace_field = self.fields["grace_period_end"] + grace_field.widget = forms.DateTimeInput(attrs={"type": "datetime-local"}) + grace_initial = self.options_grace_period_end or ( + timezone.now() + timedelta(days=14) + ) + if grace_initial and timezone.is_naive(grace_initial): + grace_initial = timezone.make_aware(grace_initial) + grace_field.initial = timezone.localtime(grace_initial).strftime( + "%Y-%m-%dT%H:%M" + ) + if self.options_grace_period_end: + grace_field.help_text = ( + f"Current grace period end: " + f"{timezone.localtime(self.options_grace_period_end)}. " + "Provide a new end to extend or shorten." + ) + self.fields["delete_comment"].initial = default_delete_comment + self.fields["add_comment"].initial = default_add_comment + + def is_in_grace_period(self, reference_time=None): + reference_time = reference_time or timezone.now() + return bool( + self.options_grace_period_end + and self.options_grace_period_end > reference_time + ) + + def clean(self): + cleaned_data = super().clean() + question = self.question + action = cleaned_data.get("action") + current_options = question.options or [] + options_history = question.options_history or [] + now = timezone.now() + + if not question.options or not question.options_history: + raise forms.ValidationError( + "This question needs options and an options history to update." + ) + + if not action: + return cleaned_data + + if action == self.ACTION_RENAME: + old_option = cleaned_data.get("old_option") + new_option = cleaned_data.get("new_option", "") + + if not old_option: + self.add_error("old_option", "Select an option to rename.") + if not new_option or not new_option.strip(): + self.add_error("new_option", "Enter the new option text.") + new_option = (new_option or "").strip() + + if self.errors: + return cleaned_data + + if old_option not in current_options: + self.add_error( + "old_option", "Selected option is not part of the current choices." + ) + return cleaned_data + + new_options = [ + new_option if opt == old_option else opt for opt in current_options + ] + if len(set(new_options)) != len(new_options): + self.add_error( + "new_option", "New option duplicates an existing option." + ) + return cleaned_data + + cleaned_data["target_option"] = old_option + cleaned_data["parsed_new_option"] = new_option + return cleaned_data + + if action == self.ACTION_DELETE: + options_to_delete = cleaned_data.get("options_to_delete") or [] + catch_all_option = current_options[-1] if current_options else None + if not options_to_delete: + self.add_error( + "options_to_delete", "Select at least one option to delete." + ) + return cleaned_data + if catch_all_option and catch_all_option in options_to_delete: + self.add_error( + "options_to_delete", "The final catch-all option cannot be deleted." + ) + + new_options = [ + opt for opt in current_options if opt not in options_to_delete + ] + if len(new_options) < 2: + self.add_error( + "options_to_delete", + "At least one option in addition to the catch-all must remain.", + ) + if self.is_in_grace_period(now): + self.add_error( + "options_to_delete", + "Options cannot change during an active grace period.", + ) + + if self.errors: + return cleaned_data + + serializer = MultipleChoiceOptionsUpdateSerializer( + context={"question": question} + ) + try: + serializer.validate_new_options(new_options, options_history, None) + except DRFValidationError as exc: + raise forms.ValidationError(exc.detail or exc.args) + + cleaned_data["options_to_delete"] = options_to_delete + cleaned_data["delete_comment"] = cleaned_data.get("delete_comment", "") + return cleaned_data + + if action == self.ACTION_ADD: + new_options_raw = cleaned_data.get("new_options") or "" + grace_period_end = cleaned_data.get("grace_period_end") + if grace_period_end and timezone.is_naive(grace_period_end): + grace_period_end = timezone.make_aware(grace_period_end) + cleaned_data["grace_period_end"] = grace_period_end + new_options_list = [ + opt.strip() for opt in new_options_raw.split(",") if opt.strip() + ] + if not new_options_list: + self.add_error("new_options", "Enter at least one option to add.") + if len(new_options_list) != len(set(new_options_list)): + self.add_error("new_options", "New options list includes duplicates.") + + duplicate_existing = set(current_options).intersection(new_options_list) + if duplicate_existing: + self.add_error( + "new_options", + f"Options already exist: {', '.join(sorted(duplicate_existing))}", + ) + + if not grace_period_end: + self.add_error( + "grace_period_end", "Grace period end is required when adding." + ) + elif grace_period_end <= now: + self.add_error( + "grace_period_end", "Grace period end must be in the future." + ) + if self.is_in_grace_period(now): + self.add_error( + "grace_period_end", + "Options cannot change during an active grace period.", + ) + + if self.errors: + return cleaned_data + + serializer = MultipleChoiceOptionsUpdateSerializer( + context={"question": question} + ) + new_options = current_options[:-1] + new_options_list + current_options[-1:] + try: + serializer.validate_new_options( + new_options, options_history, grace_period_end + ) + except DRFValidationError as exc: + raise forms.ValidationError(exc.detail or exc.args) + + cleaned_data["new_options_list"] = new_options_list + cleaned_data["grace_period_end"] = grace_period_end + cleaned_data["add_comment"] = cleaned_data.get("add_comment", "") + return cleaned_data + + if action == self.ACTION_CHANGE_GRACE: + new_grace_end = cleaned_data.get("grace_period_end") + if new_grace_end and timezone.is_naive(new_grace_end): + new_grace_end = timezone.make_aware(new_grace_end) + cleaned_data["grace_period_end"] = new_grace_end + + if not new_grace_end: + self.add_error( + "grace_period_end", "New grace period end is required to change it." + ) + elif new_grace_end <= now: + self.add_error( + "grace_period_end", "Grace period end must be in the future." + ) + + if not self.is_in_grace_period(now): + self.add_error( + "grace_period_end", + "There is no active grace period to change.", + ) + + if self.errors: + return cleaned_data + + cleaned_data["new_grace_period_end"] = new_grace_end + return cleaned_data + + if action == self.ACTION_REORDER: + if len(options_history) > 1: + self.add_error( + "action", + "Options can only be reordered when there is a single options history entry.", + ) + return cleaned_data + + positions: dict[str, int] = {} + seen_values: set[int] = set() + + for option, field_name in getattr(self, "reorder_field_names", []): + value = cleaned_data.get(field_name) + if value is None: + self.add_error(field_name, "Enter an order value.") + continue + if value in seen_values: + self.add_error( + field_name, + "Order value must be unique.", + ) + continue + seen_values.add(value) + positions[option] = value + + if self.errors: + return cleaned_data + + if len(positions) != len(current_options): + raise forms.ValidationError("Provide an order value for every option.") + + desired_order = [ + option + for option, _ in sorted(positions.items(), key=lambda item: item[1]) + ] + cleaned_data["new_order"] = desired_order + return cleaned_data + + raise forms.ValidationError("Invalid action selected.") + + @admin.register(Question) class QuestionAdmin(CustomTranslationAdmin, DynamicArrayMixin): list_display = [ @@ -33,7 +432,13 @@ class QuestionAdmin(CustomTranslationAdmin, DynamicArrayMixin): "curation_status", "post_link", ] - readonly_fields = ["post_link", "view_forecasts"] + readonly_fields = [ + "post_link", + "view_forecasts", + "options", + "options_history", + "update_mc_options", + ] search_fields = [ "id", "title_original", @@ -84,6 +489,22 @@ def view_forecasts(self, obj): url = reverse("admin:questions_forecast_changelist") + f"?question={obj.id}" return format_html('View Forecasts', url) + def update_mc_options(self, obj): + if not obj: + return "Save the question to manage options." + if obj.type != Question.QuestionType.MULTIPLE_CHOICE: + return "Option updates are available for multiple choice questions only." + if not obj.options_history or not obj.options: + return "Options and options history are required to update choices." + url = reverse("admin:questions_question_update_options", args=[obj.id]) + return format_html( + 'Update multiple choice options' + '

Rename, delete, or add options while keeping history.

', + url, + ) + + update_mc_options.short_description = "Multiple choice options" + def should_update_translations(self, obj): post = obj.get_post() is_private = post.default_project.default_permission is None @@ -91,12 +512,34 @@ def should_update_translations(self, obj): return not is_private and is_approved + def get_urls(self): + urls = super().get_urls() + custom_urls = [ + path( + "/update-options/", + self.admin_site.admin_view(self.update_options_view), + name="questions_question_update_options", + ), + ] + return custom_urls + urls + def get_fields(self, request, obj=None): fields = super().get_fields(request, obj) + + def insert_after(target_field: str, new_field: str): + if new_field in fields: + fields.remove(new_field) + if target_field in fields: + fields.insert(fields.index(target_field) + 1, new_field) + else: + fields.append(new_field) + for field in ["post_link", "view_forecasts"]: if field in fields: fields.remove(field) fields.insert(0, field) + if obj: + insert_after("options_history", "update_mc_options") return fields def save_model(self, request, obj, form, change): @@ -139,6 +582,122 @@ def export_selected_questions_data_anonymized( ): return self.export_selected_questions_data(request, queryset, anonymized=True) + def update_options_view(self, request, question_id: int): + question = Question.objects.filter(pk=question_id).first() + if not question: + raise Http404("Question not found.") + if not self.has_change_permission(request, question): + raise PermissionDenied + + change_url = reverse("admin:questions_question_change", args=[question.id]) + if question.type != Question.QuestionType.MULTIPLE_CHOICE: + messages.error( + request, "Option updates are available for multiple choice questions." + ) + return HttpResponseRedirect(change_url) + if not question.options or not question.options_history: + messages.error( + request, + "Options and options history are required before updating choices.", + ) + return HttpResponseRedirect(change_url) + + form = MultipleChoiceOptionsAdminForm( + question, data=request.POST or None, prefix="options" + ) + if request.method == "POST" and form.is_valid(): + action = form.cleaned_data["action"] + if action == form.ACTION_RENAME: + old_option = form.cleaned_data["target_option"] + new_option = form.cleaned_data["parsed_new_option"] + multiple_choice_rename_option(question, old_option, new_option) + question.save(update_fields=["options", "options_history"]) + self.message_user( + request, f"Renamed option '{old_option}' to '{new_option}'." + ) + elif action == form.ACTION_REORDER: + new_order = form.cleaned_data["new_order"] + multiple_choice_reorder_options(question, new_order) + question.save(update_fields=["options", "options_history"]) + self.message_user( + request, + "Reordered options.", + ) + elif action == form.ACTION_DELETE: + options_to_delete = form.cleaned_data["options_to_delete"] + delete_comment = form.cleaned_data.get("delete_comment", "") + multiple_choice_delete_options( + question, + options_to_delete, + comment_author=request.user, + timestep=timezone.now(), + comment_text=delete_comment, + ) + question.save(update_fields=["options", "options_history"]) + self.message_user( + request, + f"Deleted {len(options_to_delete)} option" + f"{'' if len(options_to_delete) == 1 else 's'}.", + ) + elif action == form.ACTION_ADD: + new_options = form.cleaned_data["new_options_list"] + grace_period_end = form.cleaned_data["grace_period_end"] + add_comment = form.cleaned_data.get("add_comment", "") + if timezone.is_naive(grace_period_end): + grace_period_end = timezone.make_aware(grace_period_end) + multiple_choice_add_options( + question, + new_options, + grace_period_end=grace_period_end, + comment_author=request.user, + timestep=timezone.now(), + comment_text=add_comment, + ) + question.save(update_fields=["options", "options_history"]) + self.message_user( + request, + f"Added {len(new_options)} option" + f"{'' if len(new_options) == 1 else 's'}.", + ) + elif action == form.ACTION_CHANGE_GRACE: + new_grace_period_end = form.cleaned_data["new_grace_period_end"] + if timezone.is_naive(new_grace_period_end): + new_grace_period_end = timezone.make_aware(new_grace_period_end) + multiple_choice_change_grace_period_end( + question, + new_grace_period_end, + comment_author=request.user, + timestep=timezone.now(), + ) + question.save(update_fields=["options_history"]) + self.message_user( + request, + f"Grace period end updated to {timezone.localtime(new_grace_period_end)}.", + ) + return HttpResponseRedirect(change_url) + + grace_period_end = form.options_grace_period_end + in_grace_period = form.is_in_grace_period() + + context = { + **self.admin_site.each_context(request), + "opts": self.model._meta, + "app_label": self.model._meta.app_label, + "original": question, + "question": question, + "title": f"Update options for {question}", + "form": form, + "media": self.media + form.media, + "change_url": change_url, + "current_options": question.options or [], + "all_history_options": get_all_options_from_history( + question.options_history + ), + "grace_period_end": grace_period_end, + "in_grace_period": in_grace_period, + } + return TemplateResponse(request, "admin/questions/update_options.html", context) + def rebuild_aggregation_history(self, request, queryset: QuerySet[Question]): for question in queryset: build_question_forecasts(question) diff --git a/questions/migrations/0013_forecast_source.py b/questions/migrations/0013_forecast_source.py index ccd11208eb..4230d216bf 100644 --- a/questions/migrations/0013_forecast_source.py +++ b/questions/migrations/0013_forecast_source.py @@ -15,7 +15,7 @@ class Migration(migrations.Migration): name="source", field=models.CharField( blank=True, - choices=[("api", "Api"), ("ui", "Ui")], + choices=[("api", "Api"), ("ui", "Ui"), ("automatic", "Automatic")], default="", max_length=30, null=True, diff --git a/questions/migrations/0033_question_options_history.py b/questions/migrations/0033_question_options_history.py new file mode 100644 index 0000000000..7c4b69a97b --- /dev/null +++ b/questions/migrations/0033_question_options_history.py @@ -0,0 +1,50 @@ +# Generated by Django 5.1.13 on 2025-11-15 19:35 +from datetime import datetime + + +import questions.models +from django.db import migrations, models + + +def initialize_options_history(apps, schema_editor): + Question = apps.get_model("questions", "Question") + questions = Question.objects.filter(options__isnull=False) + for question in questions: + if question.options: + question.options_history = [(datetime.min.isoformat(), question.options)] + Question.objects.bulk_update(questions, ["options_history"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ("questions", "0032_alter_aggregateforecast_forecast_values_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="forecast", + name="source", + field=models.CharField( + blank=True, + choices=[("api", "Api"), ("ui", "Ui"), ("automatic", "Automatic")], + db_index=True, + default="", + max_length=30, + null=True, + ), + ), + migrations.AddField( + model_name="question", + name="options_history", + field=models.JSONField( + blank=True, + help_text="For Multiple Choice only.\n
list of tuples: (isoformat_datetime, options_list). (json stores them as lists)\n
Records the history of options over time.\n
Initialized with (datetime.min.isoformat(), self.options) upon question creation.\n
Updated whenever options are changed.", + null=True, + validators=[questions.models.validate_options_history], + ), + ), + migrations.RunPython( + initialize_options_history, reverse_code=migrations.RunPython.noop + ), + ] diff --git a/questions/models.py b/questions/models.py index 78d1252853..a7916777dc 100644 --- a/questions/models.py +++ b/questions/models.py @@ -9,7 +9,7 @@ from sql_util.aggregates import SubqueryAggregate from questions.constants import QuestionStatus -from questions.types import AggregationMethod +from questions.types import AggregationMethod, OptionsHistoryType from scoring.constants import ScoreTypes from users.models import User from utils.models import TimeStampedModel, TranslatedModel @@ -21,6 +21,27 @@ DEFAULT_INBOUND_OUTCOME_COUNT = 200 +def validate_options_history(value): + # Expect: [ (float, [str, ...]), ... ] or equivalent + if not isinstance(value, list): + raise ValidationError("Must be a list.") + for i, item in enumerate(value): + if ( + not isinstance(item, (list, tuple)) + or len(item) != 2 + or not isinstance(item[0], str) + or not isinstance(item[1], list) + or not all(isinstance(s, str) for s in item[1]) + ): + raise ValidationError(f"Bad item at index {i}: {item!r}") + try: + datetime.fromisoformat(item[0]) + except ValueError: + raise ValidationError( + f"Bad datetime format at index {i}: {item[0]!r}, must be isoformat string" + ) + + class QuestionQuerySet(QuerySet): def annotate_forecasts_count(self): return self.annotate( @@ -198,8 +219,20 @@ class QuestionType(models.TextChoices): ) unit = models.CharField(max_length=25, blank=True) - # list of multiple choice option labels - options = ArrayField(models.CharField(max_length=200), blank=True, null=True) + # multiple choice fields + options: list[str] | None = ArrayField( + models.CharField(max_length=200), blank=True, null=True + ) + options_history: OptionsHistoryType | None = models.JSONField( + null=True, + blank=True, + validators=[validate_options_history], + help_text="""For Multiple Choice only. +
list of tuples: (isoformat_datetime, options_list). (json stores them as lists) +
Records the history of options over time. +
Initialized with (datetime.min.isoformat(), self.options) upon question creation. +
Updated whenever options are changed.""", + ) # Legacy field that will be removed possibilities = models.JSONField(null=True, blank=True) @@ -271,6 +304,9 @@ def save(self, **kwargs): self.zero_point = None if self.type != self.QuestionType.MULTIPLE_CHOICE: self.options = None + if self.type == self.QuestionType.MULTIPLE_CHOICE and not self.options_history: + # initialize options history on first save + self.options_history = [(datetime.min.isoformat(), self.options or [])] return super().save(**kwargs) @@ -570,8 +606,11 @@ class Forecast(models.Model): ) class SourceChoices(models.TextChoices): - API = "api" - UI = "ui" + API = "api" # made via the api + UI = "ui" # made using the api + # an automatically assigned forecast + # usually this means a regular forecast was split + AUTOMATIC = "automatic" # logging the source of the forecast for data purposes source = models.CharField( @@ -580,6 +619,7 @@ class SourceChoices(models.TextChoices): null=True, choices=SourceChoices.choices, default="", + db_index=True, ) distribution_input = models.JSONField( @@ -621,15 +661,17 @@ def get_prediction_values(self) -> list[float | None]: return self.probability_yes_per_category return self.continuous_cdf - def get_pmf(self) -> list[float]: + def get_pmf(self, replace_none: bool = False) -> list[float]: """ - gets the PMF for this forecast, replacing None values with 0.0 - Not for serialization use (keep None values in that case) + gets the PMF for this forecast + replaces None values with 0.0 if replace_none is True """ # TODO: return a numpy array with NaNs instead of 0.0s if self.probability_yes: return [1 - self.probability_yes, self.probability_yes] if self.probability_yes_per_category: + if not replace_none: + return self.probability_yes_per_category return [ v or 0.0 for v in self.probability_yes_per_category ] # replace None with 0.0 @@ -704,19 +746,21 @@ def get_cdf(self) -> list[float | None] | None: return self.forecast_values return None - def get_pmf(self) -> list[float]: + def get_pmf(self, replace_none: bool = False) -> list[float | None]: """ - gets the PMF for this forecast, replacing None values with 0.0 - Not for serialization use (keep None values in that case) + gets the PMF for this forecast + replacing None values with 0.0 if replace_none is True """ # TODO: return a numpy array with NaNs instead of 0.0s # grab annotation if it exists for efficiency question_type = getattr(self, "question_type", self.question.type) - forecast_values = [ - v or 0.0 for v in self.forecast_values - ] # replace None with 0.0 + forecast_values = self.forecast_values + if question_type == Question.QuestionType.MULTIPLE_CHOICE: + if not replace_none: + return forecast_values + return [v or 0.0 for v in forecast_values] # replace None with 0.0 if question_type in QUESTION_CONTINUOUS_TYPES: - cdf: list[float] = forecast_values + cdf: list[float] = forecast_values # type: ignore pmf = [cdf[0]] for i in range(1, len(cdf)): pmf.append(cdf[i] - cdf[i - 1]) diff --git a/questions/serializers/common.py b/questions/serializers/common.py index 7d579e97eb..4777d7fcd7 100644 --- a/questions/serializers/common.py +++ b/questions/serializers/common.py @@ -17,10 +17,9 @@ AggregateForecast, Forecast, ) -from questions.serializers.aggregate_forecasts import ( - serialize_question_aggregations, -) -from questions.types import QuestionMovement +from questions.serializers.aggregate_forecasts import serialize_question_aggregations +from questions.services.multiple_choice_handlers import get_all_options_from_history +from questions.types import OptionsHistoryType, QuestionMovement from users.models import User from utils.the_math.formulas import ( get_scaled_quartiles_from_cdf, @@ -40,6 +39,7 @@ class QuestionSerializer(serializers.ModelSerializer): actual_close_time = serializers.SerializerMethodField() resolution = serializers.SerializerMethodField() spot_scoring_time = serializers.SerializerMethodField() + all_options_ever = serializers.SerializerMethodField() class Meta: model = Question @@ -58,6 +58,8 @@ class Meta: "type", # Multiple-choice Questions only "options", + "all_options_ever", + "options_history", "group_variable", # Used for Group Of Questions to determine # whether question is eligible for forecasting @@ -122,6 +124,10 @@ def get_actual_close_time(self, question: Question): return min(question.scheduled_close_time, question.actual_resolve_time) return question.scheduled_close_time + def get_all_options_ever(self, question: Question): + if question.options_history: + return get_all_options_from_history(question.options_history) + def get_resolution(self, question: Question): resolution = question.resolution @@ -226,6 +232,23 @@ class Meta(QuestionWriteSerializer.Meta): "cp_reveal_time", ) + def validate(self, data: dict): + data = super().validate(data) + + if qid := data.get("id"): + question = Question.objects.get(id=qid) + if data.get("options") != question.options: + # if there are user forecasts, we can't update options this way + if question.user_forecasts.exists(): + ValidationError( + "Cannot update options through this endpoint while there are " + "user forecasts. " + "Instead, use /api/questions/update-mc-options/ or the UI on " + "the question detail page." + ) + + return data + # TODO: add validation for updating continuous question bounds @@ -394,7 +417,7 @@ class ForecastWriteSerializer(serializers.ModelSerializer): probability_yes = serializers.FloatField(allow_null=True, required=False) probability_yes_per_category = serializers.DictField( - child=serializers.FloatField(), allow_null=True, required=False + child=serializers.FloatField(allow_null=True), allow_null=True, required=False ) continuous_cdf = serializers.ListField( child=serializers.FloatField(), @@ -435,21 +458,47 @@ def binary_validation(self, probability_yes): ) return probability_yes - def multiple_choice_validation(self, probability_yes_per_category, options): + def multiple_choice_validation( + self, + probability_yes_per_category: dict[str, float | None], + current_options: list[str], + options_history: OptionsHistoryType | None, + ): if probability_yes_per_category is None: raise serializers.ValidationError( "probability_yes_per_category is required" ) if not isinstance(probability_yes_per_category, dict): raise serializers.ValidationError("Forecast must be a dictionary") - if set(probability_yes_per_category.keys()) != set(options): - raise serializers.ValidationError("Forecast must include all options") - values = [float(probability_yes_per_category[option]) for option in options] - if not all([0.001 <= v <= 0.999 for v in values]) or not np.isclose( - sum(values), 1 - ): + if not set(current_options).issubset(set(probability_yes_per_category.keys())): + raise serializers.ValidationError( + f"Forecast must reflect current options: {current_options}" + ) + all_options = get_all_options_from_history(options_history) + if not set(probability_yes_per_category.keys()).issubset(set(all_options)): + raise serializers.ValidationError( + "Forecast contains probabilities for unknown options" + ) + + values: list[float | None] = [] + for option in all_options: + value = probability_yes_per_category.get(option, None) + if option in current_options: + if (value is None) or (not (0.001 <= value <= 0.999)): + raise serializers.ValidationError( + "Probabilities for current options must be between 0.001 and 0.999" + ) + elif value is not None: + raise serializers.ValidationError( + f"Probability for inactivate option '{option}' must be null or absent" + ) + values.append(value) + if not np.isclose(sum(filter(None, values)), 1): raise serializers.ValidationError( - "All probabilities must be between 0.001 and 0.999 and sum to 1.0" + "Forecast values must sum to 1.0. " + f"Received {probability_yes_per_category} which is interpreted as " + f"values: {values} representing {all_options} " + f"with current options {current_options}" ) return values @@ -555,7 +604,7 @@ def validate(self, data): "provided for multiple choice questions" ) data["probability_yes_per_category"] = self.multiple_choice_validation( - probability_yes_per_category, question.options + probability_yes_per_category, question.options, question.options_history ) else: # Continuous question if probability_yes or probability_yes_per_category: @@ -633,6 +682,21 @@ def serialize_question( archived_scores = question.user_archived_scores user_forecasts = question.request_user_forecasts last_forecast = user_forecasts[-1] if user_forecasts else None + # if the user has a pre-registered forecast, + # replace the current forecast and anything after it + if question.type == Question.QuestionType.MULTIPLE_CHOICE: + # Right now, Multiple Choice is the only type that can have pre-registered + # forecasts + if last_forecast and last_forecast.start_time > timezone.now(): + user_forecasts = [ + f for f in user_forecasts if f.start_time < timezone.now() + ] + if user_forecasts: + last_forecast.start_time = user_forecasts[-1].start_time + user_forecasts[-1] = last_forecast + else: + last_forecast.start_time = timezone.now() + user_forecasts = [last_forecast] if ( last_forecast and last_forecast.end_time @@ -647,11 +711,7 @@ def serialize_question( many=True, ).data, "latest": ( - MyForecastSerializer( - user_forecasts[-1], - ).data - if user_forecasts - else None + MyForecastSerializer(last_forecast).data if last_forecast else None ), "score_data": dict(), } diff --git a/questions/services/forecasts.py b/questions/services/forecasts.py index 15aba16fa3..2616dc7f09 100644 --- a/questions/services/forecasts.py +++ b/questions/services/forecasts.py @@ -1,7 +1,7 @@ import logging from collections import defaultdict -from datetime import timedelta -from typing import cast, Iterable +from datetime import datetime, timedelta, timezone as dt_timezone +from typing import cast, Iterable, Literal import sentry_sdk from django.db import transaction @@ -13,6 +13,7 @@ from posts.models import PostUserSnapshot, PostSubscription from posts.services.subscriptions import create_subscription_cp_change from posts.tasks import run_on_post_forecast +from questions.services.multiple_choice_handlers import get_all_options_from_history from scoring.models import Score from users.models import User from utils.cache import cache_per_object @@ -34,21 +35,67 @@ def create_forecast( *, - question: Question = None, - user: User = None, - continuous_cdf: list[float] = None, - probability_yes: float = None, - probability_yes_per_category: list[float] = None, - distribution_input=None, + question: Question, + user: User, + continuous_cdf: list[float] | None = None, + probability_yes: float | None = None, + probability_yes_per_category: list[float | None] | None = None, + distribution_input: dict | None = None, + end_time: datetime | None = None, + source: Forecast.SourceChoices | Literal[""] | None = None, **kwargs, ): now = timezone.now() post = question.get_post() + source = source or "" + + # delete all future-dated predictions, as this one will override them + Forecast.objects.filter(question=question, author=user, start_time__gt=now).delete() + + # if the forecast to be created is for a multiple choice question during a grace + # period, we need to agument the forecast accordingly (possibly preregister) + if question.type == Question.QuestionType.MULTIPLE_CHOICE: + if not probability_yes_per_category: + raise ValueError("probability_yes_per_category required for MC questions") + options_history = question.options_history + if options_history and len(options_history) > 1: + period_end = datetime.fromisoformat(options_history[-1][0]).replace( + tzinfo=dt_timezone.utc + ) + if period_end > now: + all_options = get_all_options_from_history(question.options_history) + prior_options = options_history[-2][1] + if end_time is None or end_time > period_end: + # create a pre-registration for the given forecast + Forecast.objects.create( + question=question, + author=user, + start_time=period_end, + end_time=end_time, + probability_yes_per_category=probability_yes_per_category, + post=post, + source=Forecast.SourceChoices.AUTOMATIC, + **kwargs, + ) + end_time = period_end + + prior_pmf: list[float | None] = [None] * len(all_options) + for i, (option, value) in enumerate( + zip(all_options, probability_yes_per_category) + ): + if value is None: + continue + if option in prior_options: + prior_pmf[i] = (prior_pmf[i] or 0.0) + value + else: + prior_pmf[-1] = (prior_pmf[-1] or 0.0) + value + probability_yes_per_category = prior_pmf forecast = Forecast.objects.create( question=question, author=user, start_time=now, + end_time=end_time, continuous_cdf=continuous_cdf, probability_yes=probability_yes, probability_yes_per_category=probability_yes_per_category, @@ -56,6 +103,7 @@ def create_forecast( distribution_input if question.type in QUESTION_CONTINUOUS_TYPES else None ), post=post, + source=source, **kwargs, ) # tidy up all forecasts diff --git a/questions/services/multiple_choice_handlers.py b/questions/services/multiple_choice_handlers.py new file mode 100644 index 0000000000..d88d7e1532 --- /dev/null +++ b/questions/services/multiple_choice_handlers.py @@ -0,0 +1,376 @@ +from datetime import datetime, timezone as dt_timezone + +from django.db import transaction +from django.db.models import Q +from django.utils import timezone + +from questions.models import Question, Forecast +from questions.types import OptionsHistoryType + +# MOVE THIS serializer imports +from rest_framework import serializers +from collections import Counter +from rest_framework.exceptions import ValidationError +from users.models import User + + +class MultipleChoiceOptionsUpdateSerializer(serializers.Serializer): + options = serializers.ListField(child=serializers.CharField(), required=True) + grace_period_end = serializers.DateTimeField(required=False) + + def validate_new_options( + self, + new_options: list[str], + options_history: OptionsHistoryType, + grace_period_end: datetime | None = None, + ): + datetime_str, current_options = options_history[-1] + ts = ( + datetime.fromisoformat(datetime_str) + .replace(tzinfo=dt_timezone.utc) + .timestamp() + ) + if new_options == current_options: # no change + return + if len(new_options) == len(current_options): # renaming + if any(v > 1 for v in Counter(new_options).values()): + ValidationError("new_options includes duplicate labels") + elif timezone.now().timestamp() < ts: + raise ValidationError("options cannot change during a grace period") + elif len(new_options) < len(current_options): # deletion + if len(new_options) < 2: + raise ValidationError("Must have 2 or more options") + if new_options[-1] != current_options[-1]: + raise ValidationError("Cannot delete last option") + if [o for o in new_options if o not in current_options]: + raise ValidationError( + "options cannot change name while some are being deleted" + ) + elif len(new_options) > len(current_options): # addition + if not grace_period_end or grace_period_end <= timezone.now(): + raise ValidationError( + "grace_period_end must be in the future if adding options" + ) + if new_options[-1] != current_options[-1]: + raise ValidationError("Cannot add option after last option") + if [o for o in current_options if o not in new_options]: + raise ValidationError( + "options cannot change name while some are being added" + ) + + def validate(self, data: dict) -> dict: + question: Question = self.context.get("question") + if not question: + raise ValidationError("question must be provided in context") + + if question.type != Question.QuestionType.MULTIPLE_CHOICE: + raise ValidationError("question must be of multiple choice type") + + options = data.get("options") + options_history = question.options_history + if not options or not options_history: + raise ValidationError( + "updating multiple choice questions requires options " + "and question must already have options_history" + ) + + grace_period_end = data.get("grace_period_end") + self.validate_new_options(options, options_history, grace_period_end) + + return data + + +def get_all_options_from_history( + options_history: OptionsHistoryType | None, +) -> list[str]: + """Returns the list of all options ever available. The last value in the list + is always the "catch-all" option. + + example: + options_history = [ + ("2020-01-01", ["a", "b", "other"]), + ("2020-01-02", ["a", "b", "c", "other"]), + ("2020-01-03", ["a", "c", "other"]), + ] + return ["a", "b", "c", "other"] + """ + if not options_history: + raise ValueError("Cannot make master list from empty history") + designated_other_label = options_history[0][1][-1] + all_labels: list[str] = [] + for _, options in options_history: + for label in options[:-1]: + if label not in all_labels: + all_labels.append(label) + return all_labels + [designated_other_label] + + +def multiple_choice_rename_option( + question: Question, + old_option: str, + new_option: str, +) -> Question: + """ + Modifies question in place and returns it. + Renames multiple choice option in question options and options history. + """ + if question.type != Question.QuestionType.MULTIPLE_CHOICE: + raise ValueError("Question must be multiple choice") + if not question.options or old_option not in question.options: + raise ValueError("Old option not found") + if new_option in question.options: + raise ValueError("New option already exists") + if not question.options_history: + raise ValueError("Options history is empty") + + question.options = [ + new_option if opt == old_option else opt for opt in question.options + ] + for i, (timestr, options) in enumerate(question.options_history): + question.options_history[i] = ( + timestr, + [new_option if opt == old_option else opt for opt in options], + ) + + return question + + +def multiple_choice_reorder_options( + question: Question, + new_options_order: list[str], +) -> Question: + """ + Modifies question in place and returns it. + Reorders multiple choice options in question options and options history. + Requires all options ever to be present in new_options_order. + + For now, only supports reordering if options have never changed. + """ + current_options = question.options + all_options_ever = get_all_options_from_history(question.options_history) + if question.type != Question.QuestionType.MULTIPLE_CHOICE: + raise ValueError("Question must be multiple choice") + if not current_options: + raise ValueError("Question has no options") + if set(new_options_order) != set(all_options_ever): + raise ValueError("New order does not match existing options") + if not question.options_history: + raise ValueError("Options history is empty") + + if len(question.options_history) != 1: + # TODO: support reordering options with history changes + raise ValueError("Cannot reorder options that have changed") + + # update options history (it is only one entry long) + question.options_history[0] = (question.options_history[0][0], new_options_order) + question.options = new_options_order + question.save() + + # update user forecasts + # example forecast remap: all_options_ever = [a,b,c], new_options_order = [c,a,b] + # remap = [2,0,1] + # if a forecast is [0.2,0.3,0.5], then the new one is [0.5,0.2,0.3] + remap = [all_options_ever.index(option) for option in new_options_order] + for forecast in question.user_forecasts.all(): + forecast.probability_yes_per_category = [ + forecast.probability_yes_per_category[i] for i in remap + ] + forecast.save() + + # trigger recalculation of aggregates + from questions.services.forecasts import build_question_forecasts + + build_question_forecasts(question) + + return question + + +def multiple_choice_change_grace_period_end(*args, **kwargs): + raise NotImplementedError("multiple_choice_change_grace_period_end") + + +def multiple_choice_delete_options( + question: Question, + options_to_delete: list[str], + comment_author: User, + timestep: datetime | None = None, + comment_text: str | None = None, +) -> Question: + """ + Modifies question in place and returns it. + Deletes multiple choice options in question options. + Adds a new entry to options_history. + Slices all user forecasts at timestep. + Triggers recalculation of aggregates. + """ + if not options_to_delete: + return question + timestep = timestep or timezone.now() + if question.type != Question.QuestionType.MULTIPLE_CHOICE: + raise ValueError("Question must be multiple choice") + if not question.options or not all( + [opt in question.options for opt in options_to_delete] + ): + raise ValueError("Option to delete not found") + if not question.options_history: + raise ValueError("Options history is empty") + + if ( + datetime.fromisoformat(question.options_history[-1][0]).replace( + tzinfo=dt_timezone.utc + ) + > timestep + ): + raise ValueError("timestep is before the last options history entry") + + # update question + new_options = [opt for opt in question.options if opt not in options_to_delete] + all_options = get_all_options_from_history(question.options_history) + + question.options = new_options + question.options_history.append((timestep.isoformat(), new_options)) + question.save() + + # update user forecasts + user_forecasts = question.user_forecasts.filter( + Q(end_time__isnull=True) | Q(end_time__gt=timestep), + start_time__lt=timestep, + ) + forecasts_to_create: list[Forecast] = [] + for forecast in user_forecasts: + # get new PMF + previous_pmf = forecast.probability_yes_per_category + if len(previous_pmf) != len(all_options): + raise ValueError( + f"Forecast {forecast.id} PMF length does not match " + f"all options {all_options}" + ) + new_pmf: list[float | None] = [None] * len(all_options) + for value, label in zip(previous_pmf, all_options): + if value is None: + continue + if label in new_options: + new_pmf[all_options.index(label)] = ( + new_pmf[all_options.index(label)] or 0.0 + ) + value + else: + new_pmf[-1] = ( + new_pmf[-1] or 0.0 + ) + value # add to catch-all last option + + # slice forecast + if forecast.start_time >= timestep: + # forecast is completely after timestep, just update PMF + forecast.probability_yes_per_category = new_pmf + continue + forecasts_to_create.append( + Forecast( + question=question, + author=forecast.author, + start_time=timestep, + end_time=forecast.end_time, + probability_yes_per_category=new_pmf, + post=forecast.post, + source=Forecast.SourceChoices.AUTOMATIC, # mark as automatic forecast + ) + ) + forecast.end_time = timestep + + with transaction.atomic(): + Forecast.objects.bulk_update( + user_forecasts, ["end_time", "probability_yes_per_category"] + ) + Forecast.objects.bulk_create(forecasts_to_create) + + # trigger recalculation of aggregates + from questions.services.forecasts import build_question_forecasts + + build_question_forecasts(question) + + # notify users that about the change + from questions.tasks import multiple_choice_delete_option_notificiations + + multiple_choice_delete_option_notificiations( + question_id=question.id, + timestep=timestep, + comment_author_id=comment_author.id, + comment_text=comment_text, + ) + + return question + + +def multiple_choice_add_options( + question: Question, + options_to_add: list[str], + grace_period_end: datetime, + comment_author: User, + timestep: datetime | None = None, + comment_text: str | None = None, +) -> Question: + """ + Modifies question in place and returns it. + Adds multiple choice options in question options. + Adds a new entry to options_history. + Terminates all user forecasts at grace_period_end. + Triggers recalculation of aggregates. + """ + if not options_to_add: + return question + timestep = timestep or timezone.now() + if question.type != Question.QuestionType.MULTIPLE_CHOICE: + raise ValueError("Question must be multiple choice") + if not question.options or any([opt in question.options for opt in options_to_add]): + raise ValueError("Option to add already found") + if not question.options_history: + raise ValueError("Options history is empty") + + if timestep > grace_period_end: + raise ValueError("grace_period_end must end after timestep") + if ( + datetime.fromisoformat(question.options_history[-1][0]).replace( + tzinfo=dt_timezone.utc + ) + > timestep + ): + raise ValueError("timestep is before the last options history entry") + + # update question + new_options = question.options[:-1] + options_to_add + question.options[-1:] + question.options = new_options + question.options_history.append((grace_period_end.isoformat(), new_options)) + question.save() + + # update user forecasts + user_forecasts = question.user_forecasts.all() + for forecast in user_forecasts: + pmf = forecast.probability_yes_per_category + forecast.probability_yes_per_category = ( + pmf[:-1] + [None] * len(options_to_add) + [pmf[-1]] + ) + if forecast.start_time < grace_period_end and ( + forecast.end_time is None or forecast.end_time > grace_period_end + ): + forecast.end_time = grace_period_end + with transaction.atomic(): + Forecast.objects.bulk_update( + user_forecasts, ["probability_yes_per_category", "end_time"] + ) + + # trigger recalculation of aggregates + from questions.services.forecasts import build_question_forecasts + + build_question_forecasts(question) + + # notify users that about the change + from questions.tasks import multiple_choice_add_option_notificiations + + multiple_choice_add_option_notificiations( + question_id=question.id, + grace_period_end=grace_period_end, + timestep=timestep, + comment_author_id=comment_author.id, + comment_text=comment_text, + ) + + return question diff --git a/questions/tasks.py b/questions/tasks.py index 988b7e0fbe..74caf247f1 100644 --- a/questions/tasks.py +++ b/questions/tasks.py @@ -1,10 +1,12 @@ import logging -from datetime import timedelta +from datetime import datetime, timedelta import dramatiq +from django.conf import settings from django.db.models import Q from django.utils import timezone +from comments.services.common import create_comment from notifications.constants import MailingTags from notifications.services import ( NotificationPredictedQuestionResolved, @@ -15,17 +17,18 @@ ) from posts.models import Post from posts.services.subscriptions import notify_post_status_change +from questions.models import Forecast, Question, UserForecastNotification +from questions.services.common import get_outbound_question_links +from questions.services.forecasts import ( + build_question_forecasts, + get_forecasts_per_user, +) from scoring.constants import ScoreTypes from scoring.utils import score_question from users.models import User from utils.dramatiq import concurrency_retries, task_concurrent_limit +from utils.email import send_email_with_template from utils.frontend import build_frontend_account_settings_url, build_post_url -from .models import Question, UserForecastNotification -from .services.common import get_outbound_question_links -from .services.forecasts import ( - build_question_forecasts, - get_forecasts_per_user, -) @dramatiq.actor(max_backoff=10_000, retry_when=concurrency_retries(max_retries=20)) @@ -255,3 +258,160 @@ def format_time_remaining(time_remaining: timedelta): return f"{minutes} minute{'s' if minutes != 1 else ''}" else: return f"{total_seconds} second{'s' if total_seconds != 1 else ''}" + + +@dramatiq.actor +def multiple_choice_delete_option_notificiations( + question_id: int, + timestep: datetime, + comment_author_id: int, + comment_text: str | None = None, +): + question = Question.objects.get(id=question_id) + post = question.get_post() + options_history = question.options_history + removed_options = list(set(options_history[-2][1]) - set(options_history[-1][1])) + + # send out a comment + comment_author = User.objects.get(id=comment_author_id) + default_text = ( + "Options {removed_options} were removed at {timestep}. " + "Forecasts were adjusted to keep remaining probability on the catch-all." + ) + template = comment_text or default_text + try: + text = template.format(removed_options=removed_options, timestep=timestep) + except Exception: + text = f"{template} (removed options: {removed_options}, at {timestep})" + + create_comment(comment_author, post, text=text) + + forecasters = ( + User.objects.filter( + forecast__in=question.user_forecasts.filter( + Q(end_time__isnull=True) | Q(end_time__gt=timestep) + ) + ) + .exclude( + unsubscribed_mailing_tags__contains=[ + MailingTags.BEFORE_PREDICTION_AUTO_WITHDRAWAL # seems most reasonable + ] + ) + .exclude(email__isnull=True) + .exclude(email="") + .distinct("id") + .order_by("id") + ) + # send out an immediate email + for forecaster in forecasters: + send_email_with_template( + to=forecaster.email, + subject="Multiple choice option removed", + template_name="emails/multiple_choice_option_deletion.html", + context={ + "recipient": forecaster, + "email_subject_display": "Multiple choice option removed", + "similar_posts": [], + "params": { + "post": NotificationPostParams.from_post(post), + "removed_options": removed_options, + "timestep": timestep, + }, + }, + use_async=False, + from_email=settings.EMAIL_NOTIFICATIONS_USER, + ) + + +@dramatiq.actor +def multiple_choice_add_option_notificiations( + question_id: int, + grace_period_end: datetime, + timestep: datetime, + comment_author_id: int, + comment_text: str | None = None, +): + question = Question.objects.get(id=question_id) + post = question.get_post() + options_history = question.options_history + added_options = list(set(options_history[-1][1]) - set(options_history[-2][1])) + + # send out a comment + comment_author = User.objects.get(id=comment_author_id) + default_text = ( + "Options {added_options} were added at {timestep}. " + "Please update forecasts before {grace_period_end}, when existing " + "forecasts will auto-withdraw." + ) + template = comment_text or default_text + try: + text = template.format( + added_options=added_options, + timestep=timestep, + grace_period_end=grace_period_end, + ) + except Exception: + text = ( + f"{template} (added options: {added_options}, at {timestep}, " + f"grace ends: {grace_period_end})" + ) + + create_comment(comment_author, post, text=text) + + forecasters = ( + User.objects.filter( + forecast__in=question.user_forecasts.filter( + end_time=grace_period_end + ) # all effected forecasts have their end_time set to grace_period_end + ) + .exclude( + unsubscribed_mailing_tags__contains=[ + MailingTags.BEFORE_PREDICTION_AUTO_WITHDRAWAL # seems most reasonable + ] + ) + .exclude(email__isnull=True) + .exclude(email="") + .distinct("id") + .order_by("id") + ) + # send out an immediate email + for forecaster in forecasters: + send_email_with_template( + to=forecaster.email, + subject="Multiple choice options added", + template_name="emails/multiple_choice_option_addition.html", + context={ + "recipient": forecaster, + "email_subject_display": "Multiple choice options added", + "similar_posts": [], + "params": { + "post": NotificationPostParams.from_post(post), + "added_options": added_options, + "grace_period_end": grace_period_end, + "timestep": timestep, + }, + }, + use_async=False, + from_email=settings.EMAIL_NOTIFICATIONS_USER, + ) + + # schedule a followup email for 1 day before grace period + # (if grace period is more than 1 day away) + if grace_period_end - timedelta(days=1) > timestep: + for forecaster in forecasters: + UserForecastNotification.objects.filter( + user=forecaster, question=question + ).delete() # is this necessary? + UserForecastNotification.objects.update_or_create( + user=forecaster, + question=question, + defaults={ + "trigger_time": grace_period_end - timedelta(days=1), + "email_sent": False, + "forecast": Forecast.objects.filter( + question=question, author=forecaster + ) + .order_by("-start_time") + .first(), + }, + ) diff --git a/questions/types.py b/questions/types.py index 9556806b41..f87735e520 100644 --- a/questions/types.py +++ b/questions/types.py @@ -3,6 +3,8 @@ from django.db import models from django.db.models import TextChoices +OptionsHistoryType = list[tuple[str, list[str]]] + class Direction(TextChoices): UNCHANGED = "unchanged" diff --git a/scoring/score_math.py b/scoring/score_math.py index fada04f0d1..546b19d310 100644 --- a/scoring/score_math.py +++ b/scoring/score_math.py @@ -20,7 +20,7 @@ @dataclass class AggregationEntry: - pmf: np.ndarray | list[float] + pmf: np.ndarray | list[float | None] num_forecasters: int timestamp: float @@ -36,7 +36,7 @@ def get_geometric_means( timesteps.add(forecast.end_time.timestamp()) for timestep in sorted(timesteps): prediction_values = [ - f.get_pmf() + f.get_pmf(replace_none=True) for f in forecasts if f.start_time.timestamp() <= timestep and (f.end_time is None or f.end_time.timestamp() > timestep) @@ -84,9 +84,12 @@ def evaluate_forecasts_baseline_accuracy( forecast_coverage = forecast_duration / total_duration pmf = forecast.get_pmf() if question_type in ["binary", "multiple_choice"]: - forecast_score = ( - 100 * np.log(pmf[resolution_bucket] * len(pmf)) / np.log(len(pmf)) - ) + # forecasts always have `None` assigned to MC options that aren't + # available at the time. Detecting these allows us to avoid trying to + # follow the question's options_history. + options_at_time = len([p for p in pmf if p is not None]) + p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other + forecast_score = 100 * np.log(p * options_at_time) / np.log(options_at_time) else: if resolution_bucket in [0, len(pmf) - 1]: baseline = 0.05 @@ -116,8 +119,13 @@ def evaluate_forecasts_baseline_spot_forecast( if start <= spot_forecast_timestamp < end: pmf = forecast.get_pmf() if question_type in ["binary", "multiple_choice"]: + # forecasts always have `None` assigned to MC options that aren't + # available at the time. Detecting these allows us to avoid trying to + # follow the question's options_history. + options_at_time = len([p for p in pmf if p is not None]) + p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other forecast_score = ( - 100 * np.log(pmf[resolution_bucket] * len(pmf)) / np.log(len(pmf)) + 100 * np.log(p * options_at_time) / np.log(options_at_time) ) else: if resolution_bucket in [0, len(pmf) - 1]: @@ -159,17 +167,21 @@ def evaluate_forecasts_peer_accuracy( continue pmf = forecast.get_pmf() + p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other interval_scores: list[float | None] = [] for gm in geometric_mean_forecasts: if forecast_start <= gm.timestamp < forecast_end: - score = ( + gmp = ( + gm.pmf[resolution_bucket] or gm.pmf[-1] + ) # if None, read from Other + interval_score = ( 100 * (gm.num_forecasters / (gm.num_forecasters - 1)) - * np.log(pmf[resolution_bucket] / gm.pmf[resolution_bucket]) + * np.log(p / gmp) ) if question_type in QUESTION_CONTINUOUS_TYPES: - score /= 2 - interval_scores.append(score) + interval_score /= 2 + interval_scores.append(interval_score) else: interval_scores.append(None) @@ -218,10 +230,10 @@ def evaluate_forecasts_peer_spot_forecast( ) if start <= spot_forecast_timestamp < end: pmf = forecast.get_pmf() + p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other + gmp = gm.pmf[resolution_bucket] or gm.pmf[-1] # if None, read from Other forecast_score = ( - 100 - * (gm.num_forecasters / (gm.num_forecasters - 1)) - * np.log(pmf[resolution_bucket] / gm.pmf[resolution_bucket]) + 100 * (gm.num_forecasters / (gm.num_forecasters - 1)) * np.log(p / gmp) ) if question_type in QUESTION_CONTINUOUS_TYPES: forecast_score /= 2 @@ -260,11 +272,15 @@ def evaluate_forecasts_legacy_relative( continue pmf = forecast.get_pmf() + p = pmf[resolution_bucket] or pmf[-1] # if None, read from Other interval_scores: list[float | None] = [] for bf in baseline_forecasts: if forecast_start <= bf.timestamp < forecast_end: - score = np.log2(pmf[resolution_bucket] / bf.pmf[resolution_bucket]) - interval_scores.append(score) + bfp = ( + bf.pmf[resolution_bucket] or bf.pmf[-1] + ) # if None, read from Other + interval_score = np.log2(p / bfp) + interval_scores.append(interval_score) else: interval_scores.append(None) @@ -316,7 +332,7 @@ def evaluate_question( if spot_forecast_time: spot_forecast_timestamp = min(spot_forecast_time.timestamp(), actual_close_time) - # We need all user forecasts to calculated GeoMean even + # We need all user forecasts to calculate GeoMean even # if we're only scoring some or none of the users user_forecasts = question.user_forecasts.all() if only_include_user_ids: diff --git a/templates/admin/questions/update_options.html b/templates/admin/questions/update_options.html new file mode 100644 index 0000000000..7d7a426248 --- /dev/null +++ b/templates/admin/questions/update_options.html @@ -0,0 +1,182 @@ +{% extends "admin/base_site.html" %} +{% load i18n admin_urls %} + +{% block extrahead %} +{{ block.super }} +{{ media }} +{% endblock %} + +{% block breadcrumbs %} + +{% endblock %} + +{% block content %} + +
+

Current options:

+
    + {% for opt in current_options %} +
  • {{ opt }}
  • + {% endfor %} +
+

All options ever used:

+
    + {% for opt in all_history_options %} +
  • {{ opt }}
  • + {% endfor %} +
+ + {% if in_grace_period %} +
+ This question is in an active grace period until {{ grace_period_end|date:"DATETIME_FORMAT" }}. + You can rename options, change the grace period end, but adding or deleting options is temporarily disabled. +
+ {% endif %} + +
+ {% csrf_token %} + {{ form.non_field_errors }} +
+ {% for field in form %} +
+ {{ field.errors }} +
+ {{ field.label_tag }} + {{ field }} + {% if field.help_text %}

{{ field.help_text }}

{% endif %} +
+
+ {% endfor %} +
+ + + +
+
+ +{% endblock %} diff --git a/tests/unit/test_comments/test_views.py b/tests/unit/test_comments/test_views.py index 18293193e5..b5f50aaef3 100644 --- a/tests/unit/test_comments/test_views.py +++ b/tests/unit/test_comments/test_views.py @@ -14,6 +14,7 @@ KeyFactorNews, ) from comments.services.feed import get_comments_feed +from questions.models import Forecast from questions.services.forecasts import create_forecast from tests.unit.test_comments.factories import factory_comment, factory_key_factor from tests.unit.test_misc.factories import factory_itn_article @@ -689,26 +690,23 @@ def test_comment_edit_include_forecast(self, user1, user1_client, question_binar # 0. Forecast created and closed before comment creation t_forecast_expired_start = now - timedelta(hours=4) t_forecast_expired_end = now - timedelta(hours=3) - - with freeze_time(t_forecast_expired_start): - forecast_expired = create_forecast( - question=question, - user=user1, - probability_yes=0.2, - ) - - forecast_expired.end_time = t_forecast_expired_end - forecast_expired.save() + Forecast.objects.create( + question=question, + author=user1, + probability_yes=0.2, + start_time=t_forecast_expired_start, + end_time=t_forecast_expired_end, + ) # 1. Forecast active at comment creation t_forecast_1 = now - timedelta(hours=2) - - with freeze_time(t_forecast_1): - forecast_1 = create_forecast( - question=question, - user=user1, - probability_yes=0.5, - ) + forecast_1 = Forecast.objects.create( + question=question, + author=user1, + probability_yes=0.5, + start_time=t_forecast_1, + end_time=None, + ) # 2. Comment created later. t_comment = now - timedelta(hours=1) @@ -723,12 +721,15 @@ def test_comment_edit_include_forecast(self, user1, user1_client, question_binar # 3. New forecast created after comment t_forecast_2 = now - timedelta(minutes=30) - with freeze_time(t_forecast_2): - create_forecast( - question=question, - user=user1, - probability_yes=0.8, - ) + forecast_2 = Forecast.objects.create( + question=question, + author=user1, + probability_yes=0.8, + start_time=t_forecast_2, + end_time=None, + ) + forecast_1.end_time = forecast_2.start_time + forecast_1.save() # 4. Edit comment to include forecast url = reverse("comment-edit", kwargs={"pk": comment.pk}) @@ -743,12 +744,14 @@ def test_comment_edit_include_forecast(self, user1, user1_client, question_binar assert comment.included_forecast == forecast_1 # 5. Prevent overwrite if already set - with freeze_time(now): - create_forecast( - question=question, - user=user1, - probability_yes=0.9, - ) + forecast_3 = Forecast.objects.create( + question=question, + author=user1, + probability_yes=0.9, + start_time=now, + ) + forecast_2.end_time = forecast_3.start_time + forecast_2.save() # Even if we pass include_forecast=True again, it shouldn't change response = user1_client.post( @@ -779,16 +782,13 @@ def test_comment_edit_include_forecast(self, user1, user1_client, question_binar # 8. Test attaching when multiple forecasts exist before creation t_forecast_0 = now - timedelta(hours=3) - with freeze_time(t_forecast_0): - forecast_0 = create_forecast( - question=question, - user=user1, - probability_yes=0.1, - ) - - # Close it before comment creation - forecast_0.end_time = t_forecast_1 - forecast_0.save() + Forecast.objects.create( + question=question, + author=user1, + probability_yes=0.1, + start_time=t_forecast_0, + end_time=forecast_1.start_time, + ) # So at t_comment, forecast_0 is closed. Forecast_1 is open. response = user1_client.post( diff --git a/tests/unit/test_questions/conftest.py b/tests/unit/test_questions/conftest.py index 7f7ab29e4f..57ebbb3d20 100644 --- a/tests/unit/test_questions/conftest.py +++ b/tests/unit/test_questions/conftest.py @@ -9,6 +9,7 @@ __all__ = [ "question_binary", + "question_multiple_choice", "question_numeric", "conditional_1", "question_binary_with_forecast_user_1", @@ -28,6 +29,7 @@ def question_multiple_choice(): return create_question( question_type=Question.QuestionType.MULTIPLE_CHOICE, options=["a", "b", "c", "d"], + options_history=[("0001-01-01T00:00:00", ["a", "b", "c", "d"])], ) diff --git a/tests/unit/test_questions/test_models.py b/tests/unit/test_questions/test_models.py index ba405474ab..74c5e49b3f 100644 --- a/tests/unit/test_questions/test_models.py +++ b/tests/unit/test_questions/test_models.py @@ -43,3 +43,14 @@ def test_filter_within_question_period( Forecast.objects.filter(id=f1.id).filter_within_question_period().exists() == include ) + + +def test_initialize_multiple_choice_question(): + question = create_question( + question_type=Question.QuestionType.MULTIPLE_CHOICE, + options=["a", "b", "other"], + ) + question.save() + assert ( + question.options_history and question.options_history[0][1] == question.options + ) diff --git a/tests/unit/test_questions/test_services.py b/tests/unit/test_questions/test_services/test_lifecycle.py similarity index 100% rename from tests/unit/test_questions/test_services.py rename to tests/unit/test_questions/test_services/test_lifecycle.py diff --git a/tests/unit/test_questions/test_services/test_multiple_choice_handlers.py b/tests/unit/test_questions/test_services/test_multiple_choice_handlers.py new file mode 100644 index 0000000000..7341884abe --- /dev/null +++ b/tests/unit/test_questions/test_services/test_multiple_choice_handlers.py @@ -0,0 +1,456 @@ +from datetime import datetime + +import pytest # noqa + +from questions.models import Question, Forecast +from questions.services.multiple_choice_handlers import ( + multiple_choice_add_options, + multiple_choice_delete_options, + multiple_choice_rename_option, + multiple_choice_reorder_options, +) +from tests.unit.test_posts.factories import factory_post +from tests.unit.utils import datetime_aware as dt +from users.models import User + + +@pytest.mark.parametrize( + "old_option,new_option,expect_success", + [ + ("Option B", "Option D", True), + ("Option X", "Option Y", False), # old_option does not exist + ("Option A", "Option A", False), # new_option already exists + ], +) +def test_multiple_choice_rename_option( + question_multiple_choice, old_option, new_option, expect_success +): + question = question_multiple_choice + question.options = ["Option A", "Option B", "Option C"] + question.save() + + if not expect_success: + with pytest.raises(ValueError): + multiple_choice_rename_option(question, old_option, new_option) + return + updated_question = multiple_choice_rename_option(question, old_option, new_option) + + assert old_option not in updated_question.options + assert new_option in updated_question.options + assert len(updated_question.options) == 3 + + +@pytest.mark.parametrize( + "new_options_order,expect_success", + [ + (["Option A", "Option B", "Option C"], True), # no change + (["Option C", "Option B", "Option A"], True), # happy path + (["Option B", "Option A"], False), # different number of options + ( + ["Option A", "Option B", "Option C", "D"], + False, + ), # different number of options + (["Option D", "Option E", "Option F"], False), # different options + ], +) +def test_multiple_choice_reorder_options( + question_multiple_choice, user1, new_options_order, expect_success +): + question = question_multiple_choice + original_options = ["Option A", "Option B", "Option C"] + question.options = original_options + question.options_history = [(datetime.min.isoformat(), original_options)] + question.save() + Forecast.objects.create( + author=user1, + question=question, + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + + if not expect_success: + with pytest.raises(ValueError): + multiple_choice_reorder_options(question, new_options_order) + return + updated_question = multiple_choice_reorder_options(question, new_options_order) + + assert updated_question.options == new_options_order + forecast = updated_question.user_forecasts.first() + assert forecast is not None + assert forecast.probability_yes_per_category == [ + [0.2, 0.3, 0.5][original_options.index(opt)] for opt in new_options_order + ] + + +@pytest.mark.parametrize( + "initial_options,options_to_delete,forecasts,expected_forecasts,expect_success", + [ + (["a", "b", "other"], ["b"], [], [], True), # simplest path + (["a", "b", "other"], ["c"], [], [], False), # try to remove absent item + (["a", "b", "other"], ["a", "b"], [], [], True), # remove two items + ( + ["a", "b", "other"], + ["b"], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + ], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=dt(2025, 1, 1), + probability_yes_per_category=[0.2, 0.3, 0.5], + ), + Forecast( + start_time=dt(2025, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, None, 0.8], + source=Forecast.SourceChoices.AUTOMATIC, + ), + ], + True, + ), # happy path + ( + ["a", "b", "c", "other"], + ["b", "c"], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.1, 0.4], + ) + ], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=dt(2025, 1, 1), + probability_yes_per_category=[0.2, 0.3, 0.1, 0.4], + ), + Forecast( + start_time=dt(2025, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, None, None, 0.8], + source=Forecast.SourceChoices.AUTOMATIC, + ), + ], + True, + ), # happy path removing 2 + ( + ["a", "b", "other"], + ["b"], + [ + Forecast( + start_time=dt(2025, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.8], + ) + ], + [ + Forecast( + start_time=dt(2025, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.8], + ), + ], + True, + ), # forecast is at / after timestep + ( + ["a", "b", "other"], + [], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + ], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + ], + True, + ), # no effect + ( + ["a", "b", "other"], + ["b"], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.8], + ) + ], + [], + False, + ), # initial forecast is invalid + ( + ["a", "b", "other"], + ["b"], + [ + Forecast( + start_time=dt(2023, 1, 1), + end_time=dt(2024, 1, 1), + probability_yes_per_category=[0.6, 0.15, 0.25], + ), + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ), + ], + [ + Forecast( + start_time=dt(2023, 1, 1), + end_time=dt(2024, 1, 1), + probability_yes_per_category=[0.6, 0.15, 0.25], + ), + Forecast( + start_time=dt(2024, 1, 1), + end_time=dt(2025, 1, 1), + probability_yes_per_category=[0.2, 0.3, 0.5], + ), + Forecast( + start_time=dt(2025, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, None, 0.8], + source=Forecast.SourceChoices.AUTOMATIC, + ), + ], + True, + ), # preserve previous forecasts + ], +) +def test_multiple_choice_delete_options( + question_multiple_choice: Question, + user1: User, + initial_options: list[str], + options_to_delete: list[str], + forecasts: list[Forecast], + expected_forecasts: list[Forecast], + expect_success: bool, +): + question = question_multiple_choice + question.options = initial_options + question.options_history = [(datetime.min.isoformat(), initial_options)] + question.save() + factory_post(question=question) + + timestep = dt(2025, 1, 1) + for forecast in forecasts: + forecast.author = user1 + forecast.question = question + forecast.save() + + if not expect_success: + with pytest.raises(ValueError): + multiple_choice_delete_options( + question, options_to_delete, comment_author=user1, timestep=timestep + ) + return + + multiple_choice_delete_options( + question, options_to_delete, comment_author=user1, timestep=timestep + ) + + question.refresh_from_db() + expected_options = [opt for opt in initial_options if opt not in options_to_delete] + assert question.options == expected_options + ts, options = question.options_history[-1] + assert ts == ( + timestep.isoformat() if options_to_delete else datetime.min.isoformat() + ) + assert options == expected_options + + forecasts = question.user_forecasts.order_by("start_time") + assert len(forecasts) == len(expected_forecasts) + for f, e in zip(forecasts, expected_forecasts): + assert f.start_time == e.start_time + assert f.end_time == e.end_time + assert f.probability_yes_per_category == e.probability_yes_per_category + assert f.source == e.source + + +@pytest.mark.parametrize( + "initial_options,options_to_add,grace_period_end,forecasts,expected_forecasts," + "expect_success", + [ + (["a", "b", "other"], ["c"], dt(2025, 1, 1), [], [], True), # simplest path + (["a", "b", "other"], ["b"], dt(2025, 1, 1), [], [], False), # copied add + (["a", "b", "other"], ["c", "d"], dt(2025, 1, 1), [], [], True), # double add + # grace period before last options history + (["a", "b", "other"], ["c"], dt(1900, 1, 1), [], [], False), + ( + ["a", "b", "other"], + ["c"], + dt(2025, 1, 1), + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + ], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=dt(2025, 1, 1), + probability_yes_per_category=[0.2, 0.3, None, 0.5], + ) + ], + True, + ), # happy path + ( + ["a", "b", "other"], + ["c", "d"], + dt(2025, 1, 1), + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + ], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=dt(2025, 1, 1), + probability_yes_per_category=[0.2, 0.3, None, None, 0.5], + ) + ], + True, + ), # happy path adding two options + ( + ["a", "b", "other"], + ["c"], + dt(2025, 1, 1), + [ + Forecast( + start_time=dt(2025, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + ], + [ + Forecast( + start_time=dt(2025, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, None, 0.5], + ) + ], + True, + ), # forecast starts at /after grace_period_end + ( + ["a", "b", "other"], + [], + dt(2025, 1, 1), + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + ], + [ + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ) + ], + True, + ), # no effect + ( + ["a", "b", "other"], + ["c"], + dt(2025, 1, 1), + [ + Forecast( + start_time=dt(2023, 1, 1), + end_time=dt(2024, 1, 1), + probability_yes_per_category=[0.6, 0.15, 0.25], + ), + Forecast( + start_time=dt(2024, 1, 1), + end_time=None, + probability_yes_per_category=[0.2, 0.3, 0.5], + ), + ], + [ + Forecast( + start_time=dt(2023, 1, 1), + end_time=dt(2024, 1, 1), + probability_yes_per_category=[0.6, 0.15, None, 0.25], + ), + Forecast( + start_time=dt(2024, 1, 1), + end_time=dt(2025, 1, 1), + probability_yes_per_category=[0.2, 0.3, None, 0.5], + ), + ], + True, + ), # edit all forecasts including old + ], +) +def test_multiple_choice_add_options( + question_multiple_choice: Question, + user1: User, + initial_options: list[str], + options_to_add: list[str], + grace_period_end: datetime, + forecasts: list[Forecast], + expected_forecasts: list[Forecast], + expect_success: bool, +): + question = question_multiple_choice + question.options = initial_options + question.options_history = [(datetime.min.isoformat(), initial_options)] + question.save() + factory_post(question=question) + + for forecast in forecasts: + forecast.author = user1 + forecast.question = question + forecast.save() + + if not expect_success: + with pytest.raises(ValueError): + multiple_choice_add_options( + question, + options_to_add, + grace_period_end, + comment_author=user1, + timestep=dt(2024, 7, 1), + ) + return + + multiple_choice_add_options( + question, + options_to_add, + grace_period_end, + comment_author=user1, + timestep=dt(2024, 7, 1), + ) + + question.refresh_from_db() + expected_options = initial_options[:-1] + options_to_add + initial_options[-1:] + assert question.options == expected_options + ts, options = question.options_history[-1] + assert ts == ( + grace_period_end.isoformat() if options_to_add else datetime.min.isoformat() + ) + assert options == expected_options + + forecasts = question.user_forecasts.order_by("start_time") + assert len(forecasts) == len(expected_forecasts) + for f, e in zip(forecasts, expected_forecasts): + assert f.start_time == e.start_time + assert f.end_time == e.end_time + assert f.probability_yes_per_category == e.probability_yes_per_category + assert f.source == e.source diff --git a/tests/unit/test_questions/test_views.py b/tests/unit/test_questions/test_views.py index 2f009b1452..3e75a4f275 100644 --- a/tests/unit/test_questions/test_views.py +++ b/tests/unit/test_questions/test_views.py @@ -10,11 +10,13 @@ from posts.models import Post from questions.models import Forecast, Question, UserForecastNotification +from questions.types import OptionsHistoryType from questions.tasks import check_and_schedule_forecast_widrawal_due_notifications from tests.unit.test_posts.conftest import * # noqa from tests.unit.test_posts.factories import factory_post from tests.unit.test_questions.conftest import * # noqa from tests.unit.test_questions.factories import create_question +from users.models import User class TestQuestionForecast: @@ -75,30 +77,173 @@ def test_forecast_binary_invalid(self, post_binary_public, user1_client, props): ) assert response.status_code == 400 + @freeze_time("2025-01-01") @pytest.mark.parametrize( - "props", + "options_history,forecast_props,expected", [ - {"probability_yes_per_category": {"a": 0.1, "b": 0.2, "c": 0.3, "d": 0.4}}, + ( + [("0001-01-01T00:00:00", ["a", "other"])], + { + "probability_yes_per_category": { + "a": 0.6, + "other": 0.4, + }, + "end_time": "2026-01-01", + }, + [ + Forecast( + probability_yes_per_category=[0.6, 0.4], + start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc), + end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc), + ), + ], + ), # simple path + ( + [("0001-01-01T00:00:00", ["a", "b", "other"])], + { + "probability_yes_per_category": { + "a": 0.6, + "b": 0.15, + "other": 0.25, + }, + "end_time": "2026-01-01", + }, + [ + Forecast( + probability_yes_per_category=[0.6, 0.15, 0.25], + start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc), + end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc), + ), + ], + ), # simple path 3 options + ( + [ + ("0001-01-01T00:00:00", ["a", "b", "other"]), + (datetime(2024, 1, 1).isoformat(), ["a", "other"]), + ], + { + "probability_yes_per_category": { + "a": 0.6, + "other": 0.4, + }, + "end_time": "2026-01-01", + }, + [ + Forecast( + probability_yes_per_category=[0.6, None, 0.4], + start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc), + end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc), + ), + ], + ), # option deletion + ( + [ + ("0001-01-01T00:00:00", ["a", "b", "other"]), + (datetime(2024, 1, 1).isoformat(), ["a", "b", "c", "other"]), + ], + { + "probability_yes_per_category": { + "a": 0.6, + "b": 0.15, + "c": 0.20, + "other": 0.05, + }, + "end_time": "2026-01-01", + }, + [ + Forecast( + probability_yes_per_category=[0.6, 0.15, 0.20, 0.05], + start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc), + end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc), + ), + ], + ), # option addition + ( + [ + ("0001-01-01T00:00:00", ["a", "b", "other"]), + (datetime(2026, 1, 1).isoformat(), ["a", "b", "c", "other"]), + ], + { + "probability_yes_per_category": { + "a": 0.6, + "b": 0.15, + "c": 0.20, + "other": 0.05, + }, + }, + [ + Forecast( + probability_yes_per_category=[0.6, 0.15, None, 0.25], + start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc), + end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc), + ), + Forecast( + probability_yes_per_category=[0.6, 0.15, 0.20, 0.05], + start_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc), + end_time=None, + source=Forecast.SourceChoices.AUTOMATIC, + ), + ], + ), # forecasting during a grace period + ( + [ + ("0001-01-01T00:00:00", ["a", "b", "other"]), + (datetime(2026, 1, 1).isoformat(), ["a", "b", "c", "other"]), + ], + { + "probability_yes_per_category": { + "a": 0.6, + "b": 0.15, + "c": 0.20, + "other": 0.05, + }, + "end_time": "2027-01-01", + }, + [ + Forecast( + probability_yes_per_category=[0.6, 0.15, None, 0.25], + start_time=datetime(2025, 1, 1, tzinfo=dt_timezone.utc), + end_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc), + ), + Forecast( + probability_yes_per_category=[0.6, 0.15, 0.20, 0.05], + start_time=datetime(2026, 1, 1, tzinfo=dt_timezone.utc), + end_time=datetime(2027, 1, 1, tzinfo=dt_timezone.utc), + source=Forecast.SourceChoices.AUTOMATIC, + ), + ], + ), # forecasting during a grace period with end time ], ) def test_forecast_multiple_choice( - self, post_multiple_choice_public, user1, user1_client, props + self, + post_multiple_choice_public: Post, + user1: User, + user1_client, + options_history: OptionsHistoryType, + forecast_props: dict, + expected: list[Forecast], ): + question = post_multiple_choice_public.question + question.options_history = options_history + question.options = options_history[-1][1] + question.save() response = user1_client.post( self.url, - data=json.dumps( - [{"question": post_multiple_choice_public.question.id, **props}] - ), + data=json.dumps([{"question": question.id, **forecast_props}]), content_type="application/json", ) assert response.status_code == 201 - forecast = Forecast.objects.filter( - question=post_multiple_choice_public.question, author=user1 - ).first() - assert forecast - assert forecast.probability_yes_per_category == list( - props.get("probability_yes_per_category").values() - ) + forecasts = Forecast.objects.filter( + question=post_multiple_choice_public.question, + author=user1, + ).order_by("start_time") + assert len(forecasts) == len(expected) + for f, e in zip(forecasts, expected): + assert f.start_time == e.start_time + assert f.end_time == e.end_time + assert f.probability_yes_per_category == e.probability_yes_per_category + assert f.source == e.source @pytest.mark.parametrize( "props", diff --git a/tests/unit/test_scoring/test_score_math.py b/tests/unit/test_scoring/test_score_math.py index 23f5f78c71..652dcd9be3 100644 --- a/tests/unit/test_scoring/test_score_math.py +++ b/tests/unit/test_scoring/test_score_math.py @@ -47,7 +47,7 @@ def F(q=None, v=None, s=None, e=None): return forecast -def A(p: list[float] | None = None, n: int = 0, t: int | None = None): +def A(p: list[float | None] | None = None, n: int = 0, t: int | None = None): # Create an AggregationEntry object with basic values # p: pmf # n: number of forecasters @@ -75,6 +75,11 @@ class TestScoreMath: ([F()] * 100, [A(n=100)]), # maths ([F(v=0.7), F(v=0.8), F(v=0.9)], [A(p=[0.18171206, 0.79581144], n=3)]), + # multiple choice forecasts with placeholder 0s + ( + [F(q=QT.MULTIPLE_CHOICE, v=[0.6, 0.15, None, 0.25])] * 2, + [A(n=2, p=[0.6, 0.15, 0.0, 0.25])], + ), # start times ([F(), F(s=1)], [A(), A(t=1, n=2)]), ([F(), F(s=1), F(s=2)], [A(), A(t=1, n=2), A(t=2, n=3)]), @@ -85,7 +90,7 @@ class TestScoreMath: # numeric ( [F(q=QT.NUMERIC), F(q=QT.NUMERIC)], - [A(p=[0] + [1 / 200] * 200 + [0], n=2)], + [A(p=[0.0] + [1 / 200] * 200 + [0.0], n=2)], ), ( [ @@ -103,7 +108,10 @@ def test_get_geometric_means( result = get_geometric_means(forecasts) assert len(result) == len(expected) for ra, ea in zip(result, expected): - assert all(round(r, 8) == round(e, 8) for r, e in zip(ra.pmf, ea.pmf)) + assert all( + ((r == e) or (round(r, 8) == round(e, 8))) + for r, e in zip(ra.pmf, ea.pmf) + ) assert ra.num_forecasters == ea.num_forecasters assert ra.timestamp == ea.timestamp @@ -131,6 +139,37 @@ def test_get_geometric_means( ([F(v=0.9, s=5)], {}, [S(v=84.79969066 / 2, c=0.5)]), # half coverage ([F(v=2 ** (-1 / 2))], {}, [S(v=50)]), ([F(v=2 ** (-3 / 2))], {}, [S(v=-50)]), + # multiple choice w/ placeholder at index 2 + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)], + ) + ], + {"resolution_bucket": 0, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=0.0)], + ), # chosen to have a score of 0 for simplicity + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)], + ) + ], + {"resolution_bucket": 2, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=50)], + ), # same score as index == 3 since None should read from "Other" + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)], + ) + ], + {"resolution_bucket": 3, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=50)], + ), # chosen to have a score of 50 for simplicity # numeric ( [F(q=QT.NUMERIC)], @@ -199,6 +238,37 @@ def test_evaluate_forecasts_baseline_accuracy(self, forecasts, args, expected): ([F(v=0.9, s=5)], {}, [S(v=84.79969066, c=1)]), ([F(v=2 ** (-1 / 2))], {}, [S(v=50)]), ([F(v=2 ** (-3 / 2))], {}, [S(v=-50)]), + # multiple choice w/ placeholder at index 2 + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)], + ) + ], + {"resolution_bucket": 0, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=0.0)], + ), # chosen to have a score of 0 for simplicity + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)], + ) + ], + {"resolution_bucket": 2, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=50)], + ), # same score as index == 3 since None should read from "Other" + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 - 3 ** (-0.5) - 1 / 3, None, 3 ** (-0.5)], + ) + ], + {"resolution_bucket": 3, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=50)], + ), # chosen to have a score of 50 for simplicity # numeric ( [F(q=QT.NUMERIC)], @@ -319,6 +389,64 @@ def test_evaluate_forecasts_baseline_spot_forecast(self, forecasts, args, expect S(v=100 * (0.5 * 0 + 0.5 * np.log(0.9 / gmean([0.1, 0.5]))), c=0.5), ], ), + # multiple choice w/ placeholder at index 2 + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[ + 1 / 3, + 1 - (np.e ** (0.25) / 3) - 1 / 3, + None, + np.e ** (0.25) / 3, + ], + ), + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 / 3, None, 1 / 3], + ), + ], + {"resolution_bucket": 0, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=0), S(v=0)], + ), # chosen to have a score of 0 for simplicity + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[ + 1 / 3, + 1 - (np.e ** (0.25) / 3) - 1 / 3, + None, + np.e ** (0.25) / 3, + ], + ), + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 / 3, None, 1 / 3], + ), + ], + {"resolution_bucket": 2, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=25), S(v=-25)], + ), # same score as index == 3 since 0.0 should read from "Other" + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[ + 1 / 3, + 1 - (np.e ** (0.25) / 3) - 1 / 3, + None, + np.e ** (0.25) / 3, + ], + ), + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 / 3, None, 1 / 3], + ), + ], + {"resolution_bucket": 3, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=25), S(v=-25)], + ), # chosen to have a score of 25 for simplicity # TODO: add tests with base forecasts different from forecasts ], ) @@ -403,6 +531,64 @@ def test_evaluate_forecasts_peer_accuracy(self, forecasts, args, expected): {}, [S(v=100 * np.log(0.1 / 0.5)), S(v=100 * np.log(0.5 / 0.1)), S(c=0)], ), + # multiple choice w/ placeholder at index 2 + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[ + 1 / 3, + 1 - (np.e ** (0.25) / 3) - 1 / 3, + None, + np.e ** (0.25) / 3, + ], + ), + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 / 3, None, 1 / 3], + ), + ], + {"resolution_bucket": 0, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=0), S(v=0)], + ), # chosen to have a score of 0 for simplicity + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[ + 1 / 3, + 1 - (np.e ** (0.25) / 3) - 1 / 3, + None, + np.e ** (0.25) / 3, + ], + ), + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 / 3, None, 1 / 3], + ), + ], + {"resolution_bucket": 2, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=25), S(v=-25)], + ), # same score as index == 3 since None should read from "Other" + ( + [ + F( + q=QT.MULTIPLE_CHOICE, + v=[ + 1 / 3, + 1 - (np.e ** (0.25) / 3) - 1 / 3, + None, + np.e ** (0.25) / 3, + ], + ), + F( + q=QT.MULTIPLE_CHOICE, + v=[1 / 3, 1 / 3, None, 1 / 3], + ), + ], + {"resolution_bucket": 3, "question_type": QT.MULTIPLE_CHOICE}, + [S(v=25), S(v=-25)], + ), # chosen to have a score of 25 for simplicity # TODO: add tests with base forecasts different from forecasts ], ) diff --git a/tests/unit/test_utils/test_the_math/conftest.py b/tests/unit/test_utils/test_the_math/conftest.py index b048040bbf..ab4e99da78 100644 --- a/tests/unit/test_utils/test_the_math/conftest.py +++ b/tests/unit/test_utils/test_the_math/conftest.py @@ -1 +1,4 @@ -from tests.unit.test_questions.conftest import question_binary # noqa +from tests.unit.test_questions.conftest import ( # noqa + question_binary, + question_multiple_choice, +) diff --git a/tests/unit/test_utils/test_the_math/test_aggregations.py b/tests/unit/test_utils/test_the_math/test_aggregations.py index 73aaa5119e..911c9b4594 100644 --- a/tests/unit/test_utils/test_the_math/test_aggregations.py +++ b/tests/unit/test_utils/test_the_math/test_aggregations.py @@ -23,6 +23,12 @@ GoldMedalistsAggregation, JoinedBeforeDateAggregation, SingleAggregation, + compute_weighted_semi_standard_deviations, +) +from utils.typing import ( + ForecastValues, + ForecastsValues, + Weights, ) @@ -46,6 +52,64 @@ def test_summarize_array(array, max_size, expceted_array): class TestAggregations: + @pytest.mark.parametrize( + "forecasts_values, weights, expected", + [ + ( + [[0.5, 0.5]], + None, + ([0.0, 0.0], [0.0, 0.0]), + ), # Trivial + ( + [ + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + ], + None, + ([0.0, 0.0], [0.0, 0.0]), + ), # 3 unwavaring forecasts + ( + [ + [0.2, 0.8], + [0.5, 0.5], + [0.8, 0.2], + ], + None, + ([0.3, 0.3], [0.3, 0.3]), + ), # 3 unwavaring forecasts + ( + [ + [0.6, 0.15, None, 0.25], + [0.6, 0.15, None, 0.25], + ], + None, + ([0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]), + ), # identical forecasts with placeholders + ( + [ + [0.4, 0.25, None, 0.35], + [0.6, 0.15, None, 0.25], + ], + None, + ([0.1, 0.05, 0.0, 0.05], [0.1, 0.05, 0.0, 0.05]), + ), # minorly different forecasts with placeholders + ], + ) + def test_compute_weighted_semi_standard_deviations( + self, + forecasts_values: ForecastsValues, + weights: Weights | None, + expected: tuple[ForecastValues, ForecastValues], + ): + result = compute_weighted_semi_standard_deviations(forecasts_values, weights) + rl, ru = result + el, eu = expected + for v, e in zip(rl, el): + np.testing.assert_approx_equal(v, e) + for v, e in zip(ru, eu): + np.testing.assert_approx_equal(v, e) + @pytest.mark.parametrize("aggregation_name", [Agg.method for Agg in AGGREGATIONS]) def test_aggregations_initialize( self, question_binary: Question, aggregation_name: str @@ -241,46 +305,120 @@ def test_aggregations_initialize( histogram=None, ), ), + # Multiple choice with placeholders + ( + {}, + ForecastSet( + forecasts_values=[ + [0.6, 0.15, None, 0.25], + [0.6, 0.25, None, 0.15], + ], + timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc), + forecaster_ids=[1, 2], + timesteps=[ + datetime(2022, 1, 1, tzinfo=dt_timezone.utc), + datetime(2023, 1, 1, tzinfo=dt_timezone.utc), + ], + ), + True, + False, + AggregateForecast( + start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc), + method=AggregationMethod.UNWEIGHTED, + forecast_values=[0.6, 0.20, None, 0.20], + interval_lower_bounds=[0.6, 0.15, None, 0.15], + centers=[0.6, 0.20, None, 0.20], + interval_upper_bounds=[0.6, 0.25, None, 0.25], + means=[0.6, 0.20, None, 0.20], + forecaster_count=2, + ), + ), + ( + {}, + ForecastSet( + forecasts_values=[ + [0.6, 0.15, None, 0.25], + [0.6, 0.25, None, 0.15], + [0.4, 0.35, None, 0.25], + ], + timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc), + forecaster_ids=[1, 2], + timesteps=[ + datetime(2022, 1, 1, tzinfo=dt_timezone.utc), + datetime(2023, 1, 1, tzinfo=dt_timezone.utc), + ], + ), + True, + False, + AggregateForecast( + start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc), + method=AggregationMethod.UNWEIGHTED, + forecast_values=[ + 0.5453965360072925, + 0.22730173199635367, + None, + 0.22730173199635367, + ], + interval_lower_bounds=[ + 0.3635976906715284, + 0.1363810391978122, + None, + 0.1363810391978122, + ], + centers=[ + 0.5453965360072925, + 0.22730173199635367, + None, + 0.22730173199635367, + ], + interval_upper_bounds=[ + 0.5453965360072925, + 0.3182224247948951, + None, + 0.22730173199635367, + ], + means=[ + 0.5333333333333333, + 0.25, + None, + 0.21666666666666667, + ], + forecaster_count=3, + ), + ), ], ) def test_UnweightedAggregation( self, question_binary: Question, + question_multiple_choice: Question, init_params: dict, forecast_set: ForecastSet, include_stats: bool, histogram: bool, expected: AggregateForecast, ): - aggregation = UnweightedAggregation(question=question_binary, **init_params) - new_aggregation = aggregation.calculate_aggregation_entry( + if len(forecast_set.forecasts_values[0]) == 2: + question = question_binary + else: + question = question_multiple_choice + + aggregation = UnweightedAggregation(question=question, **init_params) + new_aggregation: AggregateForecast = aggregation.calculate_aggregation_entry( forecast_set, include_stats, histogram ) - assert new_aggregation.start_time == expected.start_time - assert ( - new_aggregation.forecast_values == expected.forecast_values - ) or np.allclose(new_aggregation.forecast_values, expected.forecast_values) - assert new_aggregation.forecaster_count == expected.forecaster_count - assert ( - new_aggregation.interval_lower_bounds == expected.interval_lower_bounds - ) or np.allclose( - new_aggregation.interval_lower_bounds, expected.interval_lower_bounds - ) - assert (new_aggregation.centers == expected.centers) or np.allclose( - new_aggregation.centers, expected.centers - ) - assert ( - new_aggregation.interval_upper_bounds == expected.interval_upper_bounds - ) or np.allclose( - new_aggregation.interval_upper_bounds, expected.interval_upper_bounds - ) - assert (new_aggregation.means == expected.means) or np.allclose( - new_aggregation.means, expected.means - ) - assert (new_aggregation.histogram == expected.histogram) or np.allclose( - new_aggregation.histogram, expected.histogram - ) + for r, e in [ + (new_aggregation.forecast_values, expected.forecast_values), + (new_aggregation.interval_lower_bounds, expected.interval_lower_bounds), + (new_aggregation.centers, expected.centers), + (new_aggregation.interval_upper_bounds, expected.interval_upper_bounds), + (new_aggregation.means, expected.means), + (new_aggregation.histogram, expected.histogram), + ]: + r = np.where(np.equal(r, None), np.nan, r).astype(float) + e = np.where(np.equal(e, None), np.nan, e).astype(float) + np.testing.assert_allclose(r, e, equal_nan=True) @pytest.mark.parametrize( "init_params, forecast_set, include_stats, histogram, expected", @@ -468,20 +606,52 @@ def test_UnweightedAggregation( histogram=None, ), ), + # Multiple choice with placeholders + ( + {}, + ForecastSet( + forecasts_values=[ + [0.6, 0.15, None, 0.25], + [0.6, 0.25, None, 0.15], + ], + timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc), + forecaster_ids=[1, 2], + timesteps=[ + datetime(2022, 1, 1, tzinfo=dt_timezone.utc), + datetime(2023, 1, 1, tzinfo=dt_timezone.utc), + ], + ), + True, + False, + AggregateForecast( + start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc), + method=AggregationMethod.UNWEIGHTED, + forecast_values=[0.6, 0.20, None, 0.20], + interval_lower_bounds=[0.6, 0.15, None, 0.15], + centers=[0.6, 0.20, None, 0.20], + interval_upper_bounds=[0.6, 0.25, None, 0.25], + means=[0.6, 0.20, None, 0.20], + forecaster_count=2, + ), + ), ], ) def test_RecencyWeightedAggregation( self, question_binary: Question, + question_multiple_choice: Question, init_params: dict, forecast_set: ForecastSet, include_stats: bool, histogram: bool, expected: AggregateForecast, ): - aggregation = RecencyWeightedAggregation( - question=question_binary, **init_params - ) + if len(forecast_set.forecasts_values[0]) == 2: + question = question_binary + else: + question = question_multiple_choice + + aggregation = RecencyWeightedAggregation(question=question, **init_params) new_aggregation = aggregation.calculate_aggregation_entry( forecast_set, include_stats, histogram ) diff --git a/tests/unit/test_utils/test_the_math/test_formulas.py b/tests/unit/test_utils/test_the_math/test_formulas.py index 54f78dd357..30bb3d3e13 100644 --- a/tests/unit/test_utils/test_the_math/test_formulas.py +++ b/tests/unit/test_utils/test_the_math/test_formulas.py @@ -15,7 +15,12 @@ class TestFormulas: binary_details = {"type": Question.QuestionType.BINARY} multiple_choice_details = { "type": Question.QuestionType.MULTIPLE_CHOICE, - "options": ["A", "B", "C"], + "options": ["a", "c", "Other"], + "options_history": [ + (0, ["a", "b", "Other"]), + (100, ["a", "Other"]), + (200, ["a", "c", "Other"]), + ], } numeric_details = { "type": Question.QuestionType.NUMERIC, @@ -57,8 +62,10 @@ class TestFormulas: ("", binary_details, None), (None, binary_details, None), # Multiple choice questions - ("A", multiple_choice_details, 0), - ("C", multiple_choice_details, 2), + ("a", multiple_choice_details, 0), + ("b", multiple_choice_details, 1), + ("c", multiple_choice_details, 2), + ("Other", multiple_choice_details, 3), # Numeric questions ("below_lower_bound", numeric_details, 0), ("-2", numeric_details, 0), diff --git a/tests/unit/test_utils/test_the_math/test_measures.py b/tests/unit/test_utils/test_the_math/test_measures.py index b5ee3c8356..ab2273d2f8 100644 --- a/tests/unit/test_utils/test_the_math/test_measures.py +++ b/tests/unit/test_utils/test_the_math/test_measures.py @@ -56,14 +56,26 @@ ( [ [0.33, 0.33, 0.34], - [0.0, 0.5, 0.5], + [0.01, 0.49, 0.5], [0.4, 0.2, 0.4], [0.2, 0.6, 0.2], ], [0.1, 0.2, 0.3, 0.4], [50.0], - [[0.2, 0.5, 0.37]], + [[0.2, 0.49, 0.37]], + ), + ( + [ + [0.33, 0.33, None, 0.34], + [0.01, 0.49, None, 0.5], + [0.4, 0.2, None, 0.4], + [0.2, 0.6, None, 0.2], + ], + [0.1, 0.2, 0.3, 0.4], + [50.0], + [[0.2, 0.49, None, 0.37]], ), + # multiple choice options with placeholder values ], ) def test_weighted_percentile_2d(values, weights, percentiles, expected_result): @@ -73,7 +85,11 @@ def test_weighted_percentile_2d(values, weights, percentiles, expected_result): result = weighted_percentile_2d( values=values, weights=weights, percentiles=percentiles ) - np.testing.assert_allclose(result, expected_result) + result = np.where(np.equal(result, None), np.nan, result).astype(float) + expected_result = np.where( + np.equal(expected_result, None), np.nan, expected_result + ).astype(float) + np.testing.assert_allclose(result, expected_result, equal_nan=True) if weights is None and [percentiles] == [50.0]: # should behave like np.median numpy_medians = np.median(values, axis=0) np.testing.assert_allclose(result, [numpy_medians]) @@ -95,6 +111,7 @@ def test_percent_point_function(cdf, percentiles, expected_result): @pytest.mark.parametrize( "p1, p2, question, expected_result", [ + # binary ( [0.5, 0.5], [0.5, 0.5], @@ -107,6 +124,7 @@ def test_percent_point_function(cdf, percentiles, expected_result): Question(type="binary"), sum([-0.1 * np.log2(0.5 / 0.6), 0.1 * np.log2(0.5 / 0.4)]), # 0.05849625 ), + # multiple choice ( [0.5, 0.5], [0.5, 0.5], @@ -138,6 +156,54 @@ def test_percent_point_function(cdf, percentiles, expected_result): ] ), # 1.3169925 ), + ( + [0.2, 0.3, 0.5], + [0.2, 0.2, 0.6], + Question(type="multiple_choice"), + sum( + [ + 0, + (0.3 - 0.2) * np.log2(0.3 / 0.2), + (0.5 - 0.6) * np.log2(0.5 / 0.6), + ] + ), # 0.0847996 + ), + ( + [0.2, 0.3, None, 0.5], + [0.2, 0.3, None, 0.5], + Question(type="multiple_choice"), + 0.0, + ), # deal with Nones happily + ( + [0.2, 0.3, None, 0.5], + [0.2, 0.3, 0.1, 0.4], + Question(type="multiple_choice"), + 0.0, + ), # no difference across adding an option + ( + [0.2, 0.3, None, 0.5], + [0.2, 0.2, 0.1, 0.5], + Question(type="multiple_choice"), + sum( + [ + 0, + (0.3 - 0.2) * np.log2(0.3 / 0.2), + (0.5 - 0.6) * np.log2(0.5 / 0.6), + ] + ), # 0.0847996 + ), # difference across adding an option + ( + [0.2, 0.3, None, 0.5], + [0.1, None, 0.7, 0.2], + Question(type="multiple_choice"), + sum( + [ + (0.2 - 0.1) * np.log2(0.2 / 0.1), + (0.8 - 0.9) * np.log2(0.8 / 0.9), + ] + ), # 0.1169925 + ), # difference across removing and adding options + # continuous ( [0.01, 0.5, 0.99], [0.01, 0.5, 0.99], @@ -214,6 +280,7 @@ def test_prediction_difference_for_sorting(p1, p2, question, expected_result): @pytest.mark.parametrize( "p1, p2, question, expected_result", [ + # binary ( [0.5, 0.5], [0.5, 0.5], @@ -230,6 +297,7 @@ def test_prediction_difference_for_sorting(p1, p2, question, expected_result): (-0.1, (2 / 3) / 1), ], ), + # multiple choice ( [0.5, 0.5], [0.5, 0.5], @@ -270,6 +338,61 @@ def test_prediction_difference_for_sorting(p1, p2, question, expected_result): (-0.3, (1 / 9) / (4 / 6)), ], ), + ( + [0.2, 0.3, 0.5], + [0.2, 0.2, 0.6], + Question(type="multiple_choice"), + [ + (0.0, (2 / 8) / (2 / 8)), + (-0.1, (2 / 8) / (3 / 7)), + (0.1, (6 / 4) / (5 / 5)), + ], + ), + ( + [0.2, 0.3, None, 0.5], + [0.2, 0.3, None, 0.5], + Question(type="multiple_choice"), + [ + (0.0, (2 / 8) / (2 / 8)), + (0.0, (3 / 7) / (3 / 7)), + (0.0, 1.0), + (0.0, (5 / 5) / (5 / 5)), + ], + ), # deal with 0.0s happily + ( + [0.2, 0.3, None, 0.5], + [0.2, 0.3, 0.1, 0.4], + Question(type="multiple_choice"), + [ + (0.0, (2 / 8) / (2 / 8)), + (0.0, (3 / 7) / (3 / 7)), + (0.0, 1.0), + (0.0, (5 / 5) / (5 / 5)), + ], + ), # no difference across adding an option + ( + [0.2, 0.3, None, 0.5], + [0.2, 0.2, 0.1, 0.5], + Question(type="multiple_choice"), + [ + (0.0, (2 / 8) / (2 / 8)), + (-0.1, (2 / 8) / (3 / 7)), + (0.0, 1.0), + (0.1, (6 / 4) / (5 / 5)), + ], + ), # difference across adding an option + ( + [0.2, 0.3, None, 0.5], + [0.1, None, 0.7, 0.2], + Question(type="multiple_choice"), + [ + (-0.1, (1 / 9) / (2 / 8)), + (0.0, 1.0), + (0.0, 1.0), + (0.1, (9 / 1) / (8 / 2)), + ], + ), # difference across removing and adding options + # continuous ( [0.0, 0.5, 1.0], [0.0, 0.5, 1.0], diff --git a/utils/csv_utils.py b/utils/csv_utils.py index d92ed9f775..14cccf45ef 100644 --- a/utils/csv_utils.py +++ b/utils/csv_utils.py @@ -17,6 +17,7 @@ Forecast, QUESTION_CONTINUOUS_TYPES, ) +from questions.services.multiple_choice_handlers import get_all_options_from_history from questions.types import AggregationMethod from scoring.models import Score, ArchivedScore from users.models import User @@ -368,7 +369,9 @@ def generate_data( + "**`Default Project ID`** - the id of the default project for the Post.\n" + "**`Label`** - for a group question, this is the sub-question object.\n" + "**`Question Type`** - the type of the question. Binary, Multiple Choice, Numeric, Discrete, or Date.\n" - + "**`MC Options`** - the options for a multiple choice question, if applicable.\n" + + "**`MC Options (Current)`** - the current options for a multiple choice question, if applicable.\n" + + "**`MC Options (All)`** - the options for a multiple choice question across all time, if applicable.\n" + + "**`MC Options History`** - the history of options over time. Each entry is a isoformat time and a record of what the options were at that time.\n" + "**`Lower Bound`** - the lower bound of the forecasting range for a continuous question.\n" + "**`Open Lower Bound`** - whether the lower bound is open.\n" + "**`Upper Bound`** - the upper bound of the forecasting range for a continuous question.\n" @@ -397,7 +400,9 @@ def generate_data( "Default Project ID", "Label", "Question Type", - "MC Options", + "MC Options (Current)", + "MC Options (All)", + "MC Options History", "Lower Bound", "Open Lower Bound", "Upper Bound", @@ -446,7 +451,13 @@ def format_value(val): post.default_project_id, question.label, question.type, - question.options or None, + question.options, + ( + get_all_options_from_history(question.options_history) + if question.options_history + else None + ), + question.options_history or None, format_value(question.range_min), question.open_lower_bound, format_value(question.range_max), @@ -486,7 +497,7 @@ def format_value(val): + "**`End Time`** - the time when the forecast ends. If not populated, the forecast is still active. Note that this can be set in the future indicating an expiring forecast.\n" + "**`Forecaster Count`** - if this is an aggregate forecast, how many forecasts contribute to it.\n" + "**`Probability Yes`** - the probability of the binary question resolving to 'Yes'\n" - + "**`Probability Yes Per Category`** - a list of probabilities corresponding to each option for a multiple choice question. Cross-reference 'MC Options' in `question_data.csv`.\n" + + "**`Probability Yes Per Category`** - a list of probabilities corresponding to each option for a multiple choice question. Cross-reference 'MC Options (All)' in `question_data.csv`. Note that a Multiple Choice forecast will have None in places where the corresponding option wasn't available for forecast at the time.\n" + "**`Continuous CDF`** - the value of the CDF (cumulative distribution function) at each of the locations in the continuous range for a continuous question. Cross-reference 'Continuous Range' in `question_data.csv`.\n" + "**`Probability Below Lower Bound`** - the probability of the question resolving below the lower bound for a continuous question.\n" + "**`Probability Above Upper Bound`** - the probability of the question resolving above the upper bound for a continuous question.\n" diff --git a/utils/the_math/aggregations.py b/utils/the_math/aggregations.py index c7b53f46a7..3b0c57f3b7 100644 --- a/utils/the_math/aggregations.py +++ b/utils/the_math/aggregations.py @@ -489,6 +489,9 @@ def get_range_values( forecasts_values, weights, [25.0, 50.0, 75.0] ) centers_array = np.array(centers) + centers_array[np.equal(centers_array, 0.0) | (centers_array == 0.0)] = ( + 1.0 # avoid divide by zero + ) normalized_centers = np.array(aggregation_forecast_values) normalized_lowers = np.array(lowers) normalized_lowers[non_nones] = ( @@ -498,7 +501,7 @@ def get_range_values( ) normalized_uppers = np.array(uppers) normalized_uppers[non_nones] = ( - normalized_lowers[non_nones] + normalized_uppers[non_nones] * normalized_centers[non_nones] / centers_array[non_nones] ) @@ -641,9 +644,18 @@ def calculate_aggregation_entry( Question.QuestionType.BINARY, Question.QuestionType.MULTIPLE_CHOICE, ]: - aggregation.means = np.average( - forecast_set.forecasts_values, weights=weights, axis=0 - ).tolist() + forecasts_values = np.array(forecast_set.forecasts_values) + nones = ( + np.equal(forecasts_values[0], None) + if forecasts_values.size + else np.array([]) + ) + forecasts_values[:, nones] = np.nan + means = np.average(forecasts_values, weights=weights, axis=0).astype( + object + ) + means[np.isnan(means.astype(float))] = None + aggregation.means = means.tolist() if histogram and self.question.type in [ Question.QuestionType.BINARY, diff --git a/utils/the_math/formulas.py b/utils/the_math/formulas.py index 999444794c..d582039269 100644 --- a/utils/the_math/formulas.py +++ b/utils/the_math/formulas.py @@ -5,6 +5,7 @@ from questions.constants import UnsuccessfulResolutionType from questions.models import Question +from questions.services.multiple_choice_handlers import get_all_options_from_history from utils.typing import ForecastValues logger = logging.getLogger(__name__) @@ -33,7 +34,8 @@ def string_location_to_scaled_location( if question.type == Question.QuestionType.BINARY: return 1.0 if string_location == "yes" else 0.0 if question.type == Question.QuestionType.MULTIPLE_CHOICE: - return float(question.options.index(string_location)) + list_of_all_options = get_all_options_from_history(question.options_history) + return float(list_of_all_options.index(string_location)) # continuous if string_location == "below_lower_bound": return question.range_min - 1.0 diff --git a/utils/the_math/measures.py b/utils/the_math/measures.py index e20bd381be..7edce08712 100644 --- a/utils/the_math/measures.py +++ b/utils/the_math/measures.py @@ -17,16 +17,17 @@ def weighted_percentile_2d( percentiles: Percentiles = None, ) -> Percentiles: values = np.array(values) + sorted_values = values.copy() # avoid side effects + # replace None with np.nan for calculations (return to None at the end) + sorted_values[np.equal(sorted_values, None)] = np.nan + if weights is None: ordered_weights = np.ones_like(values) else: weights = np.array(weights) - ordered_weights = weights[values.argsort(axis=0)] + ordered_weights = weights[sorted_values.argsort(axis=0)] percentiles = np.array(percentiles or [50.0]) - sorted_values = values.copy() # avoid side effects - # replace None with -1.0 for calculations (return to None at the end) - sorted_values[np.equal(sorted_values, None)] = -1.0 sorted_values.sort(axis=0) # get the normalized cumulative weights @@ -52,10 +53,10 @@ def weighted_percentile_2d( + sorted_values[right_indexes, column_indicies] ) ) - # replace -1.0 back to None + # replace np.nan back to None weighted_percentiles = np.array(weighted_percentiles) weighted_percentiles = np.where( - weighted_percentiles == -1.0, None, weighted_percentiles + np.isnan(weighted_percentiles.astype(float)), None, weighted_percentiles ) return weighted_percentiles.tolist() @@ -104,10 +105,22 @@ def prediction_difference_for_sorting( """for binary and multiple choice, takes pmfs for continuous takes cdfs""" p1, p2 = np.array(p1), np.array(p2) - p1[np.equal(p1, None)] = -1.0 # replace None with -1.0 for calculations - p2[np.equal(p2, None)] = -1.0 # replace None with -1.0 for calculations # Uses Jeffrey's Divergence - if question_type in ["binary", "multiple_choice"]: + if question_type == Question.QuestionType.MULTIPLE_CHOICE: + # cover for Nones + p1_nones = np.equal(p1, None) + p2_nones = np.equal(p2, None) + never_nones = np.logical_not(p1_nones | p2_nones) + p1_new = p1[never_nones] + p2_new = p2[never_nones] + p1_new[-1] += sum(p1[~p1_nones & p2_nones]) + p2_new[-1] += sum(p2[~p2_nones & p1_nones]) + p1 = p1_new + p2 = p2_new + if question_type in [ + Question.QuestionType.BINARY, + Question.QuestionType.MULTIPLE_CHOICE, + ]: return sum([(p - q) * np.log2(p / q) for p, q in zip(p1, p2)]) cdf1 = np.array([1 - np.array(p1), p1]) cdf2 = np.array([1 - np.array(p2), p2]) @@ -123,14 +136,22 @@ def prediction_difference_for_display( """for binary and multiple choice, takes pmfs for continuous takes cdfs""" p1, p2 = np.array(p1), np.array(p2) - p1[np.equal(p1, None)] = -1.0 # replace None with -1.0 for calculations - p2[np.equal(p2, None)] = -1.0 # replace None with -1.0 for calculations if question.type == "binary": # single-item list of (pred diff, ratio of odds) return [(p2[1] - p1[1], (p2[1] / (1 - p2[1])) / (p1[1] / (1 - p1[1])))] elif question.type == "multiple_choice": # list of (pred diff, ratio of odds) - return [(q - p, (q / (1 - q)) / (p / (1 - p))) for p, q in zip(p1, p2)] + for p, q in zip(p1[:-1], p2[:-1]): + if p is None or q is None: + p1[-1] = (p1[-1] or 0.0) + (p or 0.0) + p2[-1] = (p2[-1] or 0.0) + (q or 0.0) + arr = [] + for p, q in zip(p1, p2): + if p is None or q is None: + arr.append((0.0, 1.0)) + else: + arr.append((q - p, (q / (1 - q)) / (p / (1 - p)))) + return arr # total earth mover's distance, assymmetric earth mover's distance x_locations = unscaled_location_to_scaled_location( np.linspace(0, 1, len(p1)), question