Skip to content

Commit 6f63682

Browse files
author
hagan
committed
need to just fix seq_date_sql
1 parent d34b5a2 commit 6f63682

37 files changed

+197
-767
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
^test_script.R
2121
non_standard_calendar.R
2222
^vignettes/articles$
23+
^vignettes/\.quarto$

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ Suggests:
3333
fansi,
3434
crayon,
3535
knitr,
36-
contoso
36+
contoso,
37+
knitr
3738
Config/testthat/edition: 3
3839
VignetteBuilder: knitr
3940
Depends:

R/abc.R

Lines changed: 62 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -91,114 +91,89 @@ abc_fn <- function(x){
9191
# category_values <- c(.2,.5,.3)
9292
# fn="n"
9393

94-
# create summary table by the group
95-
96-
if(x@value@value_vec!="n"){
97-
98-
summary_dbi <- x@datum@data |>
94+
# 1. AGGREGATE DATA
95+
# We handle both count-based (n) and sum-based (.value) categorization
96+
if (x@value@value_vec != "n") {
97+
summary_dbi <- x@datum@data |>
9998
dplyr::summarize(
100-
!!x@value@new_column_name_vec:=sum(!!x@value@value_quo,na.rm=TRUE)
101-
,.groups="drop"
102-
) |>
103-
dbplyr::window_order(desc(!!x@value@new_column_name_quo))
104-
99+
!!x@value@new_column_name_vec := sum(!!x@value@value_quo, na.rm = TRUE),
100+
.groups = "drop"
101+
)
105102
} else {
106-
107-
summary_dbi <- x@datum@data |>
108-
dplyr::summarize(
109-
!!x@value@new_column_name_vec:=dplyr::n()
110-
,.groups="drop"
111-
) |>
112-
dbplyr::window_order(desc(!!x@value@new_column_name_quo))
113-
103+
summary_dbi <- x@datum@data |>
104+
dplyr::summarize(
105+
!!x@value@new_column_name_vec := dplyr::n(),
106+
.groups = "drop"
107+
)
114108
}
115109

116-
117-
# category_tbl <- abc_obj@category_values |>
118-
# stack() |>
119-
# tibble::as_tibble() |>
120-
# dplyr::rename(
121-
# category_value=values,category_name=ind
122-
# )
123-
124-
125-
## create summary stats table
126-
110+
# 2. CALCULATE CUMULATIVE STATS
111+
# Using window functions to prepare the 'cum_unit_prop' for bucket matching
127112
stats_dbi <- summary_dbi |>
113+
dbplyr::window_order(desc(!!x@value@new_column_name_quo)) |>
128114
dplyr::mutate(
129-
cum_sum=cumsum(!!x@value@new_column_name_quo)
130-
,prop_total=!!x@value@new_column_name_quo/max(cum_sum,na.rm=TRUE)
131-
,cum_prop_total=cumsum(prop_total)
132-
,row_id=dplyr::row_number()
133-
,max_row_id=max(row_id,na.rm=TRUE)
134-
,cum_unit_prop=row_id/max_row_id
115+
cum_sum = cumsum(!!x@value@new_column_name_quo),
116+
prop_total = !!x@value@new_column_name_quo / max(cum_sum, na.rm = TRUE),
117+
cum_prop_total = cumsum(prop_total),
118+
row_id = dplyr::row_number(),
119+
max_row_id = max(row_id, na.rm = TRUE),
120+
cum_unit_prop = row_id / max_row_id # This determines the ABC bucket
135121
)
136122

123+
# 3. PREPARE THE LOOKUP TABLE (The Optimization)
124+
# Instead of glue/paste SQL strings, we create a small temp table on the DB
125+
con <- dbplyr::remote_con(stats_dbi)
137126

138-
# assign names to category list
139-
names(x@category@category_values) <- x@category@category_names
140-
141-
## create sql scripts for category CTE----------
142-
143-
category_values_vec <- glue::glue("({x@category@category_values}")
144-
145-
category_names_vec <- glue::glue("'{names(x@category@category_values)}')")
146-
147-
148-
sql_values <- stringr::str_flatten_comma(paste0(category_values_vec,", ",category_names_vec))
149-
150-
sql_base <- "WITH my_cte (category_value, category_name) AS (
151-
VALUES "
152-
153-
sql_end <- ") select * from my_cte"
154-
155-
sql_category_dbi <- paste0(sql_base,sql_values,sql_end)
156-
157-
## grab connection from the summary tbl
127+
# Ensure category names exist (default to A, B, C... if empty)
128+
cat_names <- x@category@category_names %||% LETTERS[seq_along(x@category@category_values)]
158129

159-
con <- dbplyr::remote_con(stats_dbi)
130+
cat_lookup_df <- data.frame(
131+
category_value = x@category@category_values,
132+
category_name = cat_names
133+
)
160134

161-
## create category table to be used later
162-
category_dbi <- dplyr::tbl(con,dplyr::sql(sql_category_dbi))
135+
# Copy the tiny threshold table to the database
136+
category_dbi <- dplyr::copy_to(
137+
dest = con,
138+
df = cat_lookup_df,
139+
name = paste0("tmp_abc_", sample(1000:9999, 1)), # Random name to avoid collisions
140+
overwrite = TRUE,
141+
temporary = TRUE
142+
)
163143

164-
# join together stats table and category table and then filter to reduce duplicate matches
144+
# 4. PERFORM THE OPTIMIZED JOIN
145+
# Instead of an inequality join which is essentially a filtered Cartesian product:
146+
#
165147
out <- stats_dbi |>
166-
dplyr::left_join(
167-
category_dbi
168-
,by=dplyr::join_by(cum_unit_prop<=category_value)
169-
) |>
170-
dplyr::mutate(
171-
delta=category_value-cum_unit_prop
172-
) |>
173-
dplyr::mutate(
174-
row_id_rank=rank(delta)
175-
,.by=row_id
176-
) |>
177-
dplyr::filter(
178-
row_id_rank==1
148+
dplyr::cross_join(category_dbi) |>
149+
# Find all thresholds that are greater than or equal to our current position
150+
dplyr::filter(cum_unit_prop <= category_value) |>
151+
# Use a window function to pick the 'closest' (smallest) threshold
152+
dplyr::mutate(
153+
dist_rank = rank(category_value),
154+
.by = row_id
179155
) |>
180-
dplyr::select(-c(row_id_rank,delta))
181-
182-
## previous ------
183-
#
184-
# out <- stats_tbl |>
185-
# dplyr::left_join(
186-
# category_tbl
187-
# ,by=dplyr::join_by(dplyr::closest(cum_unit_prop<=category_value))
188-
# ) |>
189-
# arrange(category_name)
156+
dplyr::filter(dist_rank == 1) |>
157+
# Cleanup intermediate columns
158+
dplyr::select(
159+
-dist_rank,
160+
-category_value,
161+
-row_id,
162+
-max_row_id,
163+
-cum_unit_prop
164+
)
190165

191166
return(out)
192167

193168

194169
}
195170

171+
196172
#' @title Cohort Analysis
197173
#' @name cohort
198174
#' @param .data tibble or dbi object
199175
#' @param .date date column
200176
#' @param .value id column
201-
#' @param calendar_type clarify the calendar type; 'standard' or '554'
202177
#' @param period_label do you want period labels or the dates c(TRUE , FALSE)
203178
#' @param time_unit do you want summarize the date column to 'day', 'week', 'month','quarter' or 'year'
204179
#'
@@ -216,7 +191,7 @@ abc_fn <- function(x){
216191
#' @return segment object
217192
#' @export
218193
#'
219-
cohort <- function(.data,.date,.value,calendar_type,time_unit="month",period_label=FALSE){
194+
cohort <- function(.data,.date,.value,time_unit="month",period_label=FALSE){
220195

221196
## test data
222197

@@ -229,7 +204,7 @@ cohort <- function(.data,.date,.value,calendar_type,time_unit="month",period_lab
229204
x <- segment_cohort(
230205
datum= datum(
231206
.data
232-
,calendar_type = calendar_type
207+
,calendar_type = "standard"
233208
,date_vec = rlang::as_label(rlang::enquo(.date))
234209
)
235210
,fn = fn(
@@ -359,4 +334,5 @@ cohort_fn <- function(x){
359334

360335

361336

362-
utils::globalVariables(c("category","delta","row_id_rank","cohort_date","period_id","cohort_id","category_value"))
337+
338+
utils::globalVariables(c("category", "delta", "row_id_rank", "cohort_date", "period_id", "cohort_id", "category_value", "dist_rank"))

R/methods.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ S7::method(create_calendar,ti) <- function(x){
8686
start_date = start_date,
8787
end_date = x@datum@max_date,
8888
time_unit = x@time_unit@value,
89-
con = dbplyr::remote_con(x@datum@data)
89+
.con = dbplyr::remote_con(x@datum@data)
9090
)
9191

9292
# 5. Build the Scaffolding ------------------------------------------------

0 commit comments

Comments
 (0)