@@ -91,114 +91,89 @@ abc_fn <- function(x){
9191 # category_values <- c(.2,.5,.3)
9292 # fn="n"
9393
94- # create summary table by the group
95-
96- if (x @ value @ value_vec != " n" ){
97-
98- summary_dbi <- x @ datum @ data | >
94+ # 1. AGGREGATE DATA
95+ # We handle both count-based (n) and sum-based (.value) categorization
96+ if (x @ value @ value_vec != " n" ) {
97+ summary_dbi <- x @ datum @ data | >
9998 dplyr :: summarize(
100- !! x @ value @ new_column_name_vec : = sum(!! x @ value @ value_quo ,na.rm = TRUE )
101- ,.groups = " drop"
102- ) | >
103- dbplyr :: window_order(desc(!! x @ value @ new_column_name_quo ))
104-
99+ !! x @ value @ new_column_name_vec : = sum(!! x @ value @ value_quo , na.rm = TRUE ),
100+ .groups = " drop"
101+ )
105102 } else {
106-
107- summary_dbi <- x @ datum @ data | >
108- dplyr :: summarize(
109- !! x @ value @ new_column_name_vec : = dplyr :: n()
110- ,.groups = " drop"
111- ) | >
112- dbplyr :: window_order(desc(!! x @ value @ new_column_name_quo ))
113-
103+ summary_dbi <- x @ datum @ data | >
104+ dplyr :: summarize(
105+ !! x @ value @ new_column_name_vec : = dplyr :: n(),
106+ .groups = " drop"
107+ )
114108 }
115109
116-
117- # category_tbl <- abc_obj@category_values |>
118- # stack() |>
119- # tibble::as_tibble() |>
120- # dplyr::rename(
121- # category_value=values,category_name=ind
122- # )
123-
124-
125- # # create summary stats table
126-
110+ # 2. CALCULATE CUMULATIVE STATS
111+ # Using window functions to prepare the 'cum_unit_prop' for bucket matching
127112 stats_dbi <- summary_dbi | >
113+ dbplyr :: window_order(desc(!! x @ value @ new_column_name_quo )) | >
128114 dplyr :: mutate(
129- cum_sum = cumsum(!! x @ value @ new_column_name_quo )
130- , prop_total = !! x @ value @ new_column_name_quo / max(cum_sum ,na.rm = TRUE )
131- , cum_prop_total = cumsum(prop_total )
132- , row_id = dplyr :: row_number()
133- , max_row_id = max(row_id ,na.rm = TRUE )
134- , cum_unit_prop = row_id / max_row_id
115+ cum_sum = cumsum(!! x @ value @ new_column_name_quo ),
116+ prop_total = !! x @ value @ new_column_name_quo / max(cum_sum , na.rm = TRUE ),
117+ cum_prop_total = cumsum(prop_total ),
118+ row_id = dplyr :: row_number(),
119+ max_row_id = max(row_id , na.rm = TRUE ),
120+ cum_unit_prop = row_id / max_row_id # This determines the ABC bucket
135121 )
136122
123+ # 3. PREPARE THE LOOKUP TABLE (The Optimization)
124+ # Instead of glue/paste SQL strings, we create a small temp table on the DB
125+ con <- dbplyr :: remote_con(stats_dbi )
137126
138- # assign names to category list
139- names(x @ category @ category_values ) <- x @ category @ category_names
140-
141- # # create sql scripts for category CTE----------
142-
143- category_values_vec <- glue :: glue(" ({x@category@category_values}" )
144-
145- category_names_vec <- glue :: glue(" '{names(x@category@category_values)}')" )
146-
147-
148- sql_values <- stringr :: str_flatten_comma(paste0(category_values_vec ," , " ,category_names_vec ))
149-
150- sql_base <- " WITH my_cte (category_value, category_name) AS (
151- VALUES "
152-
153- sql_end <- " ) select * from my_cte"
154-
155- sql_category_dbi <- paste0(sql_base ,sql_values ,sql_end )
156-
157- # # grab connection from the summary tbl
127+ # Ensure category names exist (default to A, B, C... if empty)
128+ cat_names <- x @ category @ category_names %|| % LETTERS [seq_along(x @ category @ category_values )]
158129
159- con <- dbplyr :: remote_con(stats_dbi )
130+ cat_lookup_df <- data.frame (
131+ category_value = x @ category @ category_values ,
132+ category_name = cat_names
133+ )
160134
161- # # create category table to be used later
162- category_dbi <- dplyr :: tbl(con ,dplyr :: sql(sql_category_dbi ))
135+ # Copy the tiny threshold table to the database
136+ category_dbi <- dplyr :: copy_to(
137+ dest = con ,
138+ df = cat_lookup_df ,
139+ name = paste0(" tmp_abc_" , sample(1000 : 9999 , 1 )), # Random name to avoid collisions
140+ overwrite = TRUE ,
141+ temporary = TRUE
142+ )
163143
164- # join together stats table and category table and then filter to reduce duplicate matches
144+ # 4. PERFORM THE OPTIMIZED JOIN
145+ # Instead of an inequality join which is essentially a filtered Cartesian product:
146+ #
165147 out <- stats_dbi | >
166- dplyr :: left_join(
167- category_dbi
168- ,by = dplyr :: join_by(cum_unit_prop < = category_value )
169- ) | >
170- dplyr :: mutate(
171- delta = category_value - cum_unit_prop
172- ) | >
173- dplyr :: mutate(
174- row_id_rank = rank(delta )
175- ,.by = row_id
176- ) | >
177- dplyr :: filter(
178- row_id_rank == 1
148+ dplyr :: cross_join(category_dbi ) | >
149+ # Find all thresholds that are greater than or equal to our current position
150+ dplyr :: filter(cum_unit_prop < = category_value ) | >
151+ # Use a window function to pick the 'closest' (smallest) threshold
152+ dplyr :: mutate(
153+ dist_rank = rank(category_value ),
154+ .by = row_id
179155 ) | >
180- dplyr :: select(- c(row_id_rank ,delta ))
181-
182- # # previous ------
183- #
184- # out <- stats_tbl |>
185- # dplyr::left_join(
186- # category_tbl
187- # ,by=dplyr::join_by(dplyr::closest(cum_unit_prop<=category_value))
188- # ) |>
189- # arrange(category_name)
156+ dplyr :: filter(dist_rank == 1 ) | >
157+ # Cleanup intermediate columns
158+ dplyr :: select(
159+ - dist_rank ,
160+ - category_value ,
161+ - row_id ,
162+ - max_row_id ,
163+ - cum_unit_prop
164+ )
190165
191166 return (out )
192167
193168
194169}
195170
171+
196172# ' @title Cohort Analysis
197173# ' @name cohort
198174# ' @param .data tibble or dbi object
199175# ' @param .date date column
200176# ' @param .value id column
201- # ' @param calendar_type clarify the calendar type; 'standard' or '554'
202177# ' @param period_label do you want period labels or the dates c(TRUE , FALSE)
203178# ' @param time_unit do you want summarize the date column to 'day', 'week', 'month','quarter' or 'year'
204179# '
@@ -216,7 +191,7 @@ abc_fn <- function(x){
216191# ' @return segment object
217192# ' @export
218193# '
219- cohort <- function (.data ,.date ,.value ,calendar_type , time_unit = " month" ,period_label = FALSE ){
194+ cohort <- function (.data ,.date ,.value ,time_unit = " month" ,period_label = FALSE ){
220195
221196 # # test data
222197
@@ -229,7 +204,7 @@ cohort <- function(.data,.date,.value,calendar_type,time_unit="month",period_lab
229204 x <- segment_cohort(
230205 datum = datum(
231206 .data
232- ,calendar_type = calendar_type
207+ ,calendar_type = " standard "
233208 ,date_vec = rlang :: as_label(rlang :: enquo(.date ))
234209 )
235210 ,fn = fn(
@@ -359,4 +334,5 @@ cohort_fn <- function(x){
359334
360335
361336
362- utils :: globalVariables(c(" category" ," delta" ," row_id_rank" ," cohort_date" ," period_id" ," cohort_id" ," category_value" ))
337+
338+ utils :: globalVariables(c(" category" , " delta" , " row_id_rank" , " cohort_date" , " period_id" , " cohort_id" , " category_value" , " dist_rank" ))
0 commit comments