Skip to content

Commit 56a8bc1

Browse files
committed
don't hardcode WHIR_N_VARS in recursion program
1 parent 3d28f9a commit 56a8bc1

File tree

11 files changed

+105
-81
lines changed

11 files changed

+105
-81
lines changed

crates/lean_compiler/snark_lib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,7 @@ def match_range(value: int, *args):
122122
raise AssertionError(f"Value {value} not in any range")
123123

124124
def hint_decompose_bits_xmss(*args):
125-
_ = args
125+
_ = args
126+
127+
def hint_log2_ceil(n):
128+
return log2_ceil(n)

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2182,7 +2182,7 @@ fn simplify_lines(
21822182
"Custom hint {function_name} should not return values, at {location}"
21832183
));
21842184
}
2185-
if !hint.n_args_range().contains(&args.len()) {
2185+
if args.len() != hint.n_args() {
21862186
return Err(format!(
21872187
"Custom hint {function_name}: invalid number of arguments, at {location}"
21882188
));

crates/lean_compiler/tests/test_compiler.rs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -156,17 +156,3 @@ def main():
156156
"#;
157157
compile_and_run(&ProgramSource::Raw(program.to_string()), (&[], &[]), false);
158158
}
159-
160-
#[test]
161-
fn bug() {
162-
let program = r#"
163-
def main():
164-
three = double(1) + 1
165-
assert three == 3
166-
return
167-
168-
def double(a: Const):
169-
return a + a
170-
"#;
171-
compile_and_run(&ProgramSource::Raw(program.to_string()), (&[], &[]), false);
172-
}

crates/lean_prover/src/prove_execution.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pub struct ExecutionProof {
1515
pub proof: Vec<F>,
1616
pub proof_size_fe: usize,
1717
pub exec_summary: String,
18-
pub first_whir_n_vars: usize,
18+
pub whir_n_vars: usize,
1919
}
2020

2121
pub fn prove_execution(
@@ -110,7 +110,7 @@ pub fn prove_execution(
110110
&bytecode_acc,
111111
&traces,
112112
);
113-
let first_whir_n_vars = stacked_pcs_witness.global_polynomial.by_ref().n_vars();
113+
let whir_n_vars = stacked_pcs_witness.global_polynomial.by_ref().n_vars();
114114

115115
// logup (GKR)
116116
let logup_c = prover_state.sample();
@@ -207,7 +207,7 @@ pub fn prove_execution(
207207
proof: prover_state.raw_proof(),
208208
proof_size_fe,
209209
exec_summary,
210-
first_whir_n_vars,
210+
whir_n_vars,
211211
}
212212
}
213213

crates/lean_vm/src/core/constants.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ pub const DIGEST_LEN: usize = 8;
77

88
/// Minimum and maximum memory size (as powers of two)
99
pub const MIN_LOG_MEMORY_SIZE: usize = 16;
10-
pub const MAX_LOG_MEMORY_SIZE: usize = 29;
10+
pub const MAX_LOG_MEMORY_SIZE: usize = 25;
1111

1212
/// Maximum memory size for VM runner (specific to this implementation)
1313
pub const MAX_RUNNER_MEMORY_SIZE: usize = 1 << 24;
1414

1515
/// Minimum and maximum number of rows per table (as powers of two), both inclusive
1616
pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution.
1717
pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [
18-
(Table::execution(), 25),
19-
(Table::dot_product(), 22),
20-
(Table::poseidon16(), 21),
18+
(Table::execution(), 24),
19+
(Table::dot_product(), 20),
20+
(Table::poseidon16(), 20),
2121
];
2222

2323
/// Starting program counter
@@ -63,7 +63,7 @@ mod tests {
6363
use multilinear_toolkit::prelude::PrimeField64;
6464
use p3_util::log2_ceil_u64;
6565

66-
use crate::{DIMENSION, F, MAX_LOG_N_ROWS_PER_TABLE, Table, TableT};
66+
use crate::{DIMENSION, F, MAX_LOG_MEMORY_SIZE, MAX_LOG_N_ROWS_PER_TABLE, Table, TableT};
6767

6868
/// CRITICAL FOUR SOUNDNESS: TODO tripple check
6969
#[test]
@@ -94,10 +94,10 @@ mod tests {
9494

9595
#[test]
9696
fn ensure_not_too_big_commitment_surface() {
97-
let mut max_surface: u64 = 0;
97+
let mut max_surface: u64 = 2 * (1 << MAX_LOG_MEMORY_SIZE) as u64; // memory and acc_memory
9898
for (table, max_log_n_rows) in MAX_LOG_N_ROWS_PER_TABLE {
9999
max_surface += (table.n_committed_columns() as u64) << (max_log_n_rows as u64);
100100
}
101-
assert!(max_surface < F::ORDER_U64.next_power_of_two() / 2);
101+
assert!(max_surface <= 1 << 29); // Maximum data we can commit via WHIR using an initial folding factor of 7
102102
}
103103
}

crates/lean_vm/src/isa/hint.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use multilinear_toolkit::prelude::*;
66
use std::fmt::Debug;
77
use std::fmt::{Display, Formatter};
88
use std::hash::Hash;
9-
use std::ops::Range;
109
use strum::IntoEnumIterator;
1110
use utils::{ToUsize, pretty_integer, to_big_endian_in_field, to_little_endian_in_field};
1211

@@ -77,6 +76,7 @@ pub enum CustomHint {
7776
DecomposeBitsXMSS,
7877
DecomposeBits,
7978
LessThan,
79+
Log2Ceil,
8080
}
8181

8282
impl CustomHint {
@@ -85,14 +85,16 @@ impl CustomHint {
8585
Self::DecomposeBitsXMSS => "hint_decompose_bits_xmss",
8686
Self::DecomposeBits => "hint_decompose_bits",
8787
Self::LessThan => "hint_less_than",
88+
Self::Log2Ceil => "hint_log2_ceil",
8889
}
8990
}
9091

91-
pub fn n_args_range(&self) -> Range<usize> {
92+
pub fn n_args(&self) -> usize {
9293
match self {
93-
Self::DecomposeBitsXMSS => 5..6,
94-
Self::DecomposeBits => 4..5,
95-
Self::LessThan => 3..4,
94+
Self::DecomposeBitsXMSS => 5,
95+
Self::DecomposeBits => 4,
96+
Self::LessThan => 3,
97+
Self::Log2Ceil => 2,
9698
}
9799
}
98100

@@ -145,6 +147,11 @@ impl CustomHint {
145147
let result = if a.to_usize() < b.to_usize() { F::ONE } else { F::ZERO };
146148
ctx.memory.set(res_ptr, result)?;
147149
}
150+
Self::Log2Ceil => {
151+
let n = args[0].read_value(ctx.memory, ctx.fp)?.to_usize();
152+
let res_ptr = args[1].memory_address(ctx.fp)?;
153+
ctx.memory.set(res_ptr, F::from_usize(log2_ceil_usize(n)))?;
154+
}
148155
}
149156
Ok(())
150157
}

crates/rec_aggregation/fiat_shamir.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def fs_grinding(fs, bits):
2424
new_fs[8] = transcript_ptr + 8
2525

2626
sampled = new_fs[0]
27-
_, sampled_low_bits_value = checked_decompose_bits(sampled, bits)
27+
_, partial_sums_24 = checked_decompose_bits(sampled)
28+
sampled_low_bits_value = partial_sums_24[bits - 1]
2829
assert sampled_low_bits_value == 0
2930

3031
return new_fs
@@ -107,22 +108,22 @@ def fs_print_state(fs_state):
107108
return
108109

109110

110-
def sample_bits_const(fs, n_samples: Const, K):
111+
def sample_bits_const(fs, n_samples: Const):
111112
# return the updated fiat-shamir, and a pointer to n pointers, each pointing to 31 (boolean) field elements,
112113
sampled_bits = Array(n_samples)
113114
n_chunks = div_ceil(n_samples, 8)
114115
new_fs, sampled = fs_sample_chunks(fs, n_chunks)
115116
for i in unroll(0, n_samples):
116-
bits, _ = checked_decompose_bits(sampled[i], K)
117+
bits, _ = checked_decompose_bits(sampled[i])
117118
sampled_bits[i] = bits
118119
return new_fs, sampled_bits
119120

120121

121-
def sample_bits_dynamic(fs_state, n_samples, K):
122+
def sample_bits_dynamic(fs_state, n_samples):
122123
new_fs_state: Imu
123124
sampled_bits: Imu
124125
for r in unroll(0, WHIR_N_ROUNDS + 1):
125126
if n_samples == WHIR_NUM_QUERIES[r]:
126-
new_fs_state, sampled_bits = sample_bits_const(fs_state, WHIR_NUM_QUERIES[r], K)
127+
new_fs_state, sampled_bits = sample_bits_const(fs_state, WHIR_NUM_QUERIES[r])
127128
return new_fs_state, sampled_bits
128129
assert False, "sample_bits_dynamic called with unsupported n_samples"

crates/rec_aggregation/recursion.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,26 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
6767
assert mem_and_table_dims[i] == 0
6868
log_memory = mem_and_table_dims[0]
6969

70-
table_dims = mem_and_table_dims + 1
71-
log_cycles = table_dims[EXECUTION_TABLE_INDEX]
72-
assert log_cycles <= log_memory
70+
table_log_heights = mem_and_table_dims + 1
71+
log_n_cycles = table_log_heights[EXECUTION_TABLE_INDEX]
72+
assert log_n_cycles <= log_memory
7373

74+
log_bytecode = log2_ceil(GUEST_BYTECODE_LEN)
75+
log_bytecode_padded = maximum(log_bytecode, log_n_cycles)
76+
77+
table_heights = Array(N_TABLES)
7478
for i in unroll(0, N_TABLES):
75-
n_vars_for_table = table_dims[i]
76-
assert n_vars_for_table <= log_cycles
77-
assert MIN_LOG_N_ROWS_PER_TABLE <= n_vars_for_table
78-
assert n_vars_for_table <= MAX_LOG_N_ROWS_PER_TABLE[i]
79+
table_log_height = table_log_heights[i]
80+
table_heights[i] = powers_of_two(table_log_height)
81+
assert table_log_height <= log_n_cycles
82+
assert MIN_LOG_N_ROWS_PER_TABLE <= table_log_height
83+
assert table_log_height <= MAX_LOG_N_ROWS_PER_TABLE[i]
7984
assert MIN_LOG_MEMORY_SIZE <= log_memory
8085
assert log_memory <= MAX_LOG_MEMORY_SIZE
8186
assert log_memory <= GUEST_BYTECODE_LEN
8287

88+
stacked_n_vars = compute_stacked_n_vars(log_memory, log_bytecode_padded, table_heights)
89+
8390
fs, whir_base_root, whir_base_ood_points, whir_base_ood_evals = parse_whir_commitment_const(fs, WHIR_NUM_OOD_COMMIT)
8491

8592
fs, logup_c = fs_sample_ef(fs)
@@ -107,9 +114,6 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
107114

108115
offset: Mut = powers_of_two(log_memory)
109116

110-
log_bytecode = log2_ceil(GUEST_BYTECODE_LEN)
111-
log_n_cycles = table_dims[EXECUTION_TABLE_INDEX]
112-
log_bytecode_padded = maximum(log_bytecode, log_n_cycles)
113117
bytecode_and_acc_point = point_gkr + (N_VARS_LOGUP_GKR - log_bytecode) * DIM
114118
bytecode_multilinear_location_prefix = multilinear_location_prefix(
115119
offset / 2 ** log2_ceil(GUEST_BYTECODE_LEN), N_VARS_LOGUP_GKR - log_bytecode, point_gkr
@@ -185,8 +189,8 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
185189
for table_index in unroll(0, N_TABLES):
186190
# I] Bus (data flow between tables)
187191

188-
log_n_rows = table_dims[table_index]
189-
n_rows = powers_of_two(log_n_rows)
192+
log_n_rows = table_log_heights[table_index]
193+
n_rows = table_heights[table_index]
190194
inner_point = point_gkr + (N_VARS_LOGUP_GKR - log_n_rows) * DIM
191195
pcs_points[table_index].push(inner_point)
192196

@@ -309,7 +313,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
309313
air_alpha_powers = powers_const(air_alpha, MAX_NUM_AIR_CONSTRAINTS + 1)
310314

311315
for table_index in unroll(0, N_TABLES):
312-
log_n_rows = table_dims[table_index]
316+
log_n_rows = table_log_heights[table_index]
313317
bus_numerator_value = bus_numerators_values[table_index]
314318
bus_denominator_value = bus_denominators_values[table_index]
315319
total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index]
@@ -479,6 +483,7 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
479483
end_sum: Mut
480484
fs, folding_randomness_global, s, final_value, end_sum = whir_open(
481485
fs,
486+
stacked_n_vars,
482487
whir_base_root,
483488
whir_base_ood_points,
484489
combination_randomness_powers,
@@ -488,31 +493,31 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
488493
curr_randomness = combination_randomness_powers + WHIR_NUM_OOD_COMMIT * DIM
489494

490495
eq_memory_and_acc_point = eq_mle_extension(
491-
folding_randomness_global + (WHIR_N_VARS - log_memory) * DIM,
496+
folding_randomness_global + (stacked_n_vars - log_memory) * DIM,
492497
memory_and_acc_point,
493498
log_memory,
494499
)
495-
prefix_memory = multilinear_location_prefix(0, WHIR_N_VARS - log_memory, folding_randomness_global)
500+
prefix_memory = multilinear_location_prefix(0, stacked_n_vars - log_memory, folding_randomness_global)
496501
s = add_extension_ret(
497502
s,
498503
mul_extension_ret(mul_extension_ret(curr_randomness, prefix_memory), eq_memory_and_acc_point),
499504
)
500505
curr_randomness += DIM
501506

502-
prefix_acc_memory = multilinear_location_prefix(1, WHIR_N_VARS - log_memory, folding_randomness_global)
507+
prefix_acc_memory = multilinear_location_prefix(1, stacked_n_vars - log_memory, folding_randomness_global)
503508
s = add_extension_ret(
504509
s,
505510
mul_extension_ret(mul_extension_ret(curr_randomness, prefix_acc_memory), eq_memory_and_acc_point),
506511
)
507512
curr_randomness += DIM
508513

509514
eq_pub_mem = eq_mle_extension(
510-
folding_randomness_global + (WHIR_N_VARS - inner_public_memory_log_size) * DIM,
515+
folding_randomness_global + (stacked_n_vars - inner_public_memory_log_size) * DIM,
511516
public_memory_random_point,
512517
inner_public_memory_log_size,
513518
)
514519
prefix_pub_mem = multilinear_location_prefix(
515-
0, WHIR_N_VARS - inner_public_memory_log_size, folding_randomness_global
520+
0, stacked_n_vars - inner_public_memory_log_size, folding_randomness_global
516521
)
517522
s = add_extension_ret(
518523
s,
@@ -523,13 +528,13 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
523528
offset = powers_of_two(log_memory) * 2 # memory and acc_memory
524529

525530
eq_bytecode_acc = eq_mle_extension(
526-
folding_randomness_global + (WHIR_N_VARS - log2_ceil(GUEST_BYTECODE_LEN)) * DIM,
531+
folding_randomness_global + (stacked_n_vars - log2_ceil(GUEST_BYTECODE_LEN)) * DIM,
527532
bytecode_and_acc_point,
528533
log2_ceil(GUEST_BYTECODE_LEN),
529534
)
530535
prefix_bytecode_acc = multilinear_location_prefix(
531536
offset / 2 ** log2_ceil(GUEST_BYTECODE_LEN),
532-
WHIR_N_VARS - log2_ceil(GUEST_BYTECODE_LEN),
537+
stacked_n_vars - log2_ceil(GUEST_BYTECODE_LEN),
533538
folding_randomness_global,
534539
)
535540
s = add_extension_ret(
@@ -541,36 +546,36 @@ def recursion(inner_public_memory_log_size, inner_public_memory, proof_transcrip
541546

542547
prefix_pc_start = multilinear_location_prefix(
543548
offset + COL_PC * powers_of_two(log_n_cycles),
544-
WHIR_N_VARS,
549+
stacked_n_vars,
545550
folding_randomness_global,
546551
)
547552
s = add_extension_ret(s, mul_extension_ret(curr_randomness, prefix_pc_start))
548553
curr_randomness += DIM
549554

550555
prefix_pc_end = multilinear_location_prefix(
551556
offset + (COL_PC + 1) * powers_of_two(log_n_cycles) - 1,
552-
WHIR_N_VARS,
557+
stacked_n_vars,
553558
folding_randomness_global,
554559
)
555560
s = add_extension_ret(s, mul_extension_ret(curr_randomness, prefix_pc_end))
556561
curr_randomness += DIM
557562

558563
for table_index in unroll(0, N_TABLES):
559-
log_n_rows = table_dims[table_index]
560-
n_rows = powers_of_two(log_n_rows)
564+
log_n_rows = table_log_heights[table_index]
565+
n_rows = table_heights[table_index]
561566
total_num_cols = NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index]
562567
for i in unroll(0, len(pcs_points[table_index])):
563568
point = pcs_points[table_index][i]
564569
eq_factor = eq_mle_extension(
565570
point,
566-
folding_randomness_global + (WHIR_N_VARS - log_n_rows) * DIM,
571+
folding_randomness_global + (stacked_n_vars - log_n_rows) * DIM,
567572
log_n_rows,
568573
)
569574
for j in unroll(0, total_num_cols):
570575
if len(pcs_values[table_index][i][j]) == 1:
571576
prefix = multilinear_location_prefix(
572577
offset / n_rows + j,
573-
WHIR_N_VARS - log_n_rows,
578+
stacked_n_vars - log_n_rows,
574579
folding_randomness_global,
575580
)
576581
s = add_extension_ret(
@@ -675,6 +680,17 @@ def verify_gkr_quotient_step(fs: Mut, n_vars, point, claim_num, claim_den):
675680
return fs, postponed_point, new_claim_num, new_claim_den
676681

677682

683+
@inline
684+
def compute_stacked_n_vars(log_memory, log_bytecode_padded, tables_heights):
685+
total: Mut = powers_of_two(log_memory + 1) # memory + acc_memory
686+
total += powers_of_two(log_bytecode_padded)
687+
for table_index in unroll(0, N_TABLES):
688+
n_rows = tables_heights[table_index]
689+
total += n_rows * (NUM_COLS_F_AIR[table_index] + DIM * NUM_COLS_EF_AIR[table_index])
690+
debug_assert(30 - 24 < MIN_LOG_N_ROWS_PER_TABLE) # cf log2_ceil
691+
return MIN_LOG_N_ROWS_PER_TABLE + log2_ceil_runtime(total / 2**MIN_LOG_N_ROWS_PER_TABLE)
692+
693+
678694
def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, bus_beta, logup_alphas_eq_poly):
679695
res: Imu
680696
debug_assert(table_index < 3)

crates/rec_aggregation/src/recursion.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def main():
6969
)
7070
.unwrap();
7171

72-
let outer_whir_config = WhirConfig::<EF>::new(&inner_whir_config, proof_to_prove.first_whir_n_vars);
72+
let outer_whir_config = WhirConfig::<EF>::new(&inner_whir_config, proof_to_prove.whir_n_vars);
7373

7474
// let guest_program_commitment = {
7575
// let mut prover_state = build_prover_state();
@@ -400,7 +400,6 @@ pub(crate) fn whir_recursion_placeholder_replacements(whir_config: &WhirConfig<E
400400
format!("WHIR_FOLDING_FACTORS{}", end),
401401
format!("[{}]", folding_factors.join(", ")),
402402
);
403-
replacements.insert(format!("WHIR_N_VARS{}", end), whir_config.num_variables.to_string());
404403
replacements.insert(
405404
format!("WHIR_LOG_INV_RATE{}", end),
406405
whir_config.starting_log_inv_rate.to_string(),

0 commit comments

Comments
 (0)