@@ -67,19 +67,26 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
6767 assert mem_and_table_dims [i ] == 0
6868 log_memory = mem_and_table_dims [0 ]
6969
70- table_dims = mem_and_table_dims + 1
71- log_cycles = table_dims [EXECUTION_TABLE_INDEX ]
72- assert log_cycles <= log_memory
70+ table_log_heights = mem_and_table_dims + 1
71+ log_n_cycles = table_log_heights [EXECUTION_TABLE_INDEX ]
72+ assert log_n_cycles <= log_memory
7373
74+ log_bytecode = log2_ceil (GUEST_BYTECODE_LEN )
75+ log_bytecode_padded = maximum (log_bytecode , log_n_cycles )
76+
77+ table_heights = Array (N_TABLES )
7478 for i in unroll (0 , N_TABLES ):
75- n_vars_for_table = table_dims [i ]
76- assert n_vars_for_table <= log_cycles
77- assert MIN_LOG_N_ROWS_PER_TABLE <= n_vars_for_table
78- assert n_vars_for_table <= MAX_LOG_N_ROWS_PER_TABLE [i ]
79+ table_log_height = table_log_heights [i ]
80+ table_heights [i ] = powers_of_two (table_log_height )
81+ assert table_log_height <= log_n_cycles
82+ assert MIN_LOG_N_ROWS_PER_TABLE <= table_log_height
83+ assert table_log_height <= MAX_LOG_N_ROWS_PER_TABLE [i ]
7984 assert MIN_LOG_MEMORY_SIZE <= log_memory
8085 assert log_memory <= MAX_LOG_MEMORY_SIZE
8186 assert log_memory <= GUEST_BYTECODE_LEN
8287
88+ stacked_n_vars = compute_stacked_n_vars (log_memory , log_bytecode_padded , table_heights )
89+
8390 fs , whir_base_root , whir_base_ood_points , whir_base_ood_evals = parse_whir_commitment_const (fs , WHIR_NUM_OOD_COMMIT )
8491
8592 fs , logup_c = fs_sample_ef (fs )
@@ -107,9 +114,6 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
107114
108115 offset : Mut = powers_of_two (log_memory )
109116
110- log_bytecode = log2_ceil (GUEST_BYTECODE_LEN )
111- log_n_cycles = table_dims [EXECUTION_TABLE_INDEX ]
112- log_bytecode_padded = maximum (log_bytecode , log_n_cycles )
113117 bytecode_and_acc_point = point_gkr + (N_VARS_LOGUP_GKR - log_bytecode ) * DIM
114118 bytecode_multilinear_location_prefix = multilinear_location_prefix (
115119 offset / 2 ** log2_ceil (GUEST_BYTECODE_LEN ), N_VARS_LOGUP_GKR - log_bytecode , point_gkr
@@ -185,8 +189,8 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
185189 for table_index in unroll (0 , N_TABLES ):
186190 # I] Bus (data flow between tables)
187191
188- log_n_rows = table_dims [table_index ]
189- n_rows = powers_of_two ( log_n_rows )
192+ log_n_rows = table_log_heights [table_index ]
193+ n_rows = table_heights [ table_index ]
190194 inner_point = point_gkr + (N_VARS_LOGUP_GKR - log_n_rows ) * DIM
191195 pcs_points [table_index ].push (inner_point )
192196
@@ -309,7 +313,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
309313 air_alpha_powers = powers_const (air_alpha , MAX_NUM_AIR_CONSTRAINTS + 1 )
310314
311315 for table_index in unroll (0 , N_TABLES ):
312- log_n_rows = table_dims [table_index ]
316+ log_n_rows = table_log_heights [table_index ]
313317 bus_numerator_value = bus_numerators_values [table_index ]
314318 bus_denominator_value = bus_denominators_values [table_index ]
315319 total_num_cols = NUM_COLS_F_AIR [table_index ] + DIM * NUM_COLS_EF_AIR [table_index ]
@@ -479,6 +483,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
479483 end_sum : Mut
480484 fs , folding_randomness_global , s , final_value , end_sum = whir_open (
481485 fs ,
486+ stacked_n_vars ,
482487 whir_base_root ,
483488 whir_base_ood_points ,
484489 combination_randomness_powers ,
@@ -488,31 +493,31 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
488493 curr_randomness = combination_randomness_powers + WHIR_NUM_OOD_COMMIT * DIM
489494
490495 eq_memory_and_acc_point = eq_mle_extension (
491- folding_randomness_global + (WHIR_N_VARS - log_memory ) * DIM ,
496+ folding_randomness_global + (stacked_n_vars - log_memory ) * DIM ,
492497 memory_and_acc_point ,
493498 log_memory ,
494499 )
495- prefix_memory = multilinear_location_prefix (0 , WHIR_N_VARS - log_memory , folding_randomness_global )
500+ prefix_memory = multilinear_location_prefix (0 , stacked_n_vars - log_memory , folding_randomness_global )
496501 s = add_extension_ret (
497502 s ,
498503 mul_extension_ret (mul_extension_ret (curr_randomness , prefix_memory ), eq_memory_and_acc_point ),
499504 )
500505 curr_randomness += DIM
501506
502- prefix_acc_memory = multilinear_location_prefix (1 , WHIR_N_VARS - log_memory , folding_randomness_global )
507+ prefix_acc_memory = multilinear_location_prefix (1 , stacked_n_vars - log_memory , folding_randomness_global )
503508 s = add_extension_ret (
504509 s ,
505510 mul_extension_ret (mul_extension_ret (curr_randomness , prefix_acc_memory ), eq_memory_and_acc_point ),
506511 )
507512 curr_randomness += DIM
508513
509514 eq_pub_mem = eq_mle_extension (
510- folding_randomness_global + (WHIR_N_VARS - inner_public_memory_log_size ) * DIM ,
515+ folding_randomness_global + (stacked_n_vars - inner_public_memory_log_size ) * DIM ,
511516 public_memory_random_point ,
512517 inner_public_memory_log_size ,
513518 )
514519 prefix_pub_mem = multilinear_location_prefix (
515- 0 , WHIR_N_VARS - inner_public_memory_log_size , folding_randomness_global
520+ 0 , stacked_n_vars - inner_public_memory_log_size , folding_randomness_global
516521 )
517522 s = add_extension_ret (
518523 s ,
@@ -523,13 +528,13 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
523528 offset = powers_of_two (log_memory ) * 2 # memory and acc_memory
524529
525530 eq_bytecode_acc = eq_mle_extension (
526- folding_randomness_global + (WHIR_N_VARS - log2_ceil (GUEST_BYTECODE_LEN )) * DIM ,
531+ folding_randomness_global + (stacked_n_vars - log2_ceil (GUEST_BYTECODE_LEN )) * DIM ,
527532 bytecode_and_acc_point ,
528533 log2_ceil (GUEST_BYTECODE_LEN ),
529534 )
530535 prefix_bytecode_acc = multilinear_location_prefix (
531536 offset / 2 ** log2_ceil (GUEST_BYTECODE_LEN ),
532- WHIR_N_VARS - log2_ceil (GUEST_BYTECODE_LEN ),
537+ stacked_n_vars - log2_ceil (GUEST_BYTECODE_LEN ),
533538 folding_randomness_global ,
534539 )
535540 s = add_extension_ret (
@@ -541,36 +546,36 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
541546
542547 prefix_pc_start = multilinear_location_prefix (
543548 offset + COL_PC * powers_of_two (log_n_cycles ),
544- WHIR_N_VARS ,
549+ stacked_n_vars ,
545550 folding_randomness_global ,
546551 )
547552 s = add_extension_ret (s , mul_extension_ret (curr_randomness , prefix_pc_start ))
548553 curr_randomness += DIM
549554
550555 prefix_pc_end = multilinear_location_prefix (
551556 offset + (COL_PC + 1 ) * powers_of_two (log_n_cycles ) - 1 ,
552- WHIR_N_VARS ,
557+ stacked_n_vars ,
553558 folding_randomness_global ,
554559 )
555560 s = add_extension_ret (s , mul_extension_ret (curr_randomness , prefix_pc_end ))
556561 curr_randomness += DIM
557562
558563 for table_index in unroll (0 , N_TABLES ):
559- log_n_rows = table_dims [table_index ]
560- n_rows = powers_of_two ( log_n_rows )
564+ log_n_rows = table_log_heights [table_index ]
565+ n_rows = table_heights [ table_index ]
561566 total_num_cols = NUM_COLS_F_AIR [table_index ] + DIM * NUM_COLS_EF_AIR [table_index ]
562567 for i in unroll (0 , len (pcs_points [table_index ])):
563568 point = pcs_points [table_index ][i ]
564569 eq_factor = eq_mle_extension (
565570 point ,
566- folding_randomness_global + (WHIR_N_VARS - log_n_rows ) * DIM ,
571+ folding_randomness_global + (stacked_n_vars - log_n_rows ) * DIM ,
567572 log_n_rows ,
568573 )
569574 for j in unroll (0 , total_num_cols ):
570575 if len (pcs_values [table_index ][i ][j ]) == 1 :
571576 prefix = multilinear_location_prefix (
572577 offset / n_rows + j ,
573- WHIR_N_VARS - log_n_rows ,
578+ stacked_n_vars - log_n_rows ,
574579 folding_randomness_global ,
575580 )
576581 s = add_extension_ret (
@@ -675,6 +680,17 @@ def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den):
675680 return fs , postponed_point , new_claim_num , new_claim_den
676681
677682
683+ @inline
684+ def compute_stacked_n_vars (log_memory , log_bytecode_padded , tables_heights ):
685+ total : Mut = powers_of_two (log_memory + 1 ) # memory + acc_memory
686+ total += powers_of_two (log_bytecode_padded )
687+ for table_index in unroll (0 , N_TABLES ):
688+ n_rows = tables_heights [table_index ]
689+ total += n_rows * (NUM_COLS_F_AIR [table_index ] + DIM * NUM_COLS_EF_AIR [table_index ])
690+ debug_assert (30 - 24 < MIN_LOG_N_ROWS_PER_TABLE ) # cf log2_ceil
691+ return MIN_LOG_N_ROWS_PER_TABLE + log2_ceil_runtime (total / 2 ** MIN_LOG_N_ROWS_PER_TABLE )
692+
693+
678694def evaluate_air_constraints (table_index , inner_evals , air_alpha_powers , bus_beta , logup_alphas_eq_poly ):
679695 res : Imu
680696 debug_assert (table_index < 3 )
0 commit comments