Skip to content

Commit da258ac

Browse files
committed
use coeffs form (instead of evals) for the final poly in whir (less verifier work)
1 parent 43dddbe commit da258ac

File tree

5 files changed

+92
-41
lines changed

5 files changed

+92
-41
lines changed

crates/backend/poly/src/evals.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,46 @@ impl<F: Field, EL: Borrow<[F]>> EvaluationsList<F> for EL {
4141
}
4242
}
4343

44+
pub fn evals_to_coeffs<F: PrimeCharacteristicRing + Copy>(data: &mut [F]) {
45+
let n = data.len();
46+
let mut half = 1;
47+
while half < n {
48+
for i in (0..n).step_by(2 * half) {
49+
for j in 0..half {
50+
data[i + j + half] -= data[i + j];
51+
}
52+
}
53+
half <<= 1;
54+
}
55+
bit_reverse_permutation(data);
56+
}
57+
58+
pub fn bit_reverse_permutation<T>(data: &mut [T]) {
59+
let n = data.len();
60+
let log_n = n.ilog2() as usize;
61+
for i in 0..n {
62+
let j = i.reverse_bits() >> (usize::BITS as usize - log_n);
63+
if i < j {
64+
data.swap(i, j);
65+
}
66+
}
67+
}
68+
69+
pub fn eval_multilinear_coeffs<F, EF>(coeffs: &[F], point: &[EF]) -> EF
70+
where
71+
F: Field,
72+
EF: ExtensionField<F>,
73+
{
74+
debug_assert_eq!(coeffs.len(), 1 << point.len());
75+
match point {
76+
[] => EF::from(coeffs[0]),
77+
[x, tail @ ..] => {
78+
let (c0, c1) = coeffs.split_at(coeffs.len() / 2);
79+
eval_multilinear_coeffs(c0, tail) + eval_multilinear_coeffs(c1, tail) * *x
80+
}
81+
}
82+
}
83+
4484
/// Multiply the polynomial by a scalar factor.
4585
#[must_use]
4686
pub fn scale_poly<F: Field, EF: ExtensionField<F>>(poly: &[F], factor: EF) -> Vec<EF> {

crates/rec_aggregation/utils.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,6 @@ def poly_eq_extension(point, n: Const):
7676
)
7777
return res + (2**n - 1) * DIM
7878

79-
80-
def poly_eq_base(point, n: Const):
81-
# Example: for n = 2: eq(x, y) = [(1 - x)(1 - y), (1 - x)y, x(1 - y), xy]
82-
83-
res = Array((2 ** (n + 1) - 1))
84-
res[0] = 1
85-
for s in unroll(0, n):
86-
p = point[n - 1 - s]
87-
for i in unroll(0, 2**s):
88-
res[2 ** (s + 1) - 1 + 2**s + i] = p * res[2**s - 1 + i]
89-
res[2 ** (s + 1) - 1 + i] = res[2**s - 1 + i] - res[2 ** (s + 1) - 1 + 2**s + i]
90-
return res + (2**n - 1)
91-
92-
9379
def eq_mle_extension(a, b, n):
9480
debug_assert(n < 30)
9581
debug_assert(0 < n)
@@ -207,6 +193,31 @@ def expand_from_univariate_ext(alpha, n):
207193
return res
208194

209195

196+
def univariate_eval_on_base(coeffs, alpha, n: Const):
197+
# coeffs= univariate poly of degree 2^n
198+
# alpha: base field element
199+
# -> evaluates it at (1, alpha, alpha^2, alpha^4, ..., alpha^(2^(n-1)))
200+
alpha_powers = Array(2**n)
201+
alpha_powers[0] = 1
202+
for i in unroll(0, 2**n - 1):
203+
alpha_powers[i + 1] = alpha_powers[i] * alpha
204+
result = Array(DIM)
205+
dot_product(alpha_powers, coeffs, result, 2**n, BE)
206+
return result
207+
208+
209+
def eval_multilinear_coeffs_rev(coeffs, point, n: Const):
210+
# Evaluate multilinear polynomial in coefficient form (bit-reversed) at point.
211+
basis = Array(2**n * DIM)
212+
set_to_one(basis)
213+
for k in unroll(0, n):
214+
for j in unroll(0, 2**k):
215+
mul_extension(basis + j * DIM, point + k * DIM, basis + (j + 2**k) * DIM)
216+
result = Array(DIM)
217+
dot_product(coeffs, basis, result, 2**n, EE)
218+
return result
219+
220+
210221
def dot_product_be_dynamic(a, b, res, n):
211222
debug_assert(n <= 256)
212223
match_range(n, range(1, 257), lambda i: dot_product(a, b, res, i, BE))

crates/rec_aggregation/whir.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def whir_open(
2424
claimed_sum: Mut,
2525
):
2626
n_rounds, n_final_vars, num_queries, num_oods, query_grinding_bits, folding_grinding = get_whir_params(n_vars, initial_log_inv_rate)
27-
n_final_coeffs = powers_of_two(n_final_vars)
2827
folding_factors = Array(n_rounds + 1)
2928
folding_factors[0] = WHIR_INITIAL_FOLDING_FACTOR
3029
for i in range(1, n_rounds + 1):
@@ -88,15 +87,8 @@ def whir_open(
8887

8988
final_circle_values = all_circle_values[n_rounds]
9089
for i in range(0, num_queries[n_rounds]):
91-
powers_of_2_rev = expand_from_univariate_base(final_circle_values[i], n_final_vars)
92-
poly_eq = match_range(n_final_vars, range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), lambda n: poly_eq_base(powers_of_2_rev, n))
93-
final_pol_evaluated_on_circle = Array(DIM)
94-
dot_product_be_dynamic(
95-
poly_eq,
96-
final_coeffcients,
97-
final_pol_evaluated_on_circle,
98-
n_final_coeffs,
99-
)
90+
alpha = final_circle_values[i]
91+
final_pol_evaluated_on_circle = match_range(n_final_vars, range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), lambda n: univariate_eval_on_base(final_coeffcients, alpha, n))
10092
copy_5(final_pol_evaluated_on_circle, final_folds + i * DIM)
10193

10294
fs, all_folding_randomness[n_rounds + 1], end_sum = sumcheck_verify(fs, n_final_vars, claimed_sum, 2)
@@ -158,9 +150,7 @@ def whir_open(
158150
)
159151
s = add_extension_ret(s, s7)
160152
s = add_extension_ret(summed_ood, s)
161-
poly_eq_final = poly_eq_extension_dynamic(all_folding_randomness[n_rounds + 1], n_final_vars)
162-
final_value = Array(DIM)
163-
dot_product_ee_dynamic(poly_eq_final, final_coeffcients, final_value, n_final_coeffs)
153+
final_value = match_range(n_final_vars, range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), lambda n: eval_multilinear_coeffs_rev(final_coeffcients, all_folding_randomness[n_rounds + 1], n))
164154
# copy_5(mul_extension_ret(s, final_value), end_sum);
165155

166156
return fs, folding_randomness_global, s, final_value, end_sum

crates/whir/src/open.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,14 @@ where
184184
prover_state: &mut impl FSProver<EF>,
185185
round_state: &mut RoundState<EF>,
186186
) -> ProofResult<()> {
187-
// Directly send coefficients of the polynomial to the verifier.
188-
189-
prover_state.add_extension_scalars(&match &round_state.sumcheck_prover.evals {
187+
// Convert evaluations to coefficient form and send to the verifier.
188+
let mut coeffs = match &round_state.sumcheck_prover.evals {
190189
MleOwned::Extension(evals) => evals.clone(),
191190
MleOwned::ExtensionPacked(evals) => unpack_extension::<EF>(evals),
192191
_ => unreachable!(),
193-
});
192+
};
193+
evals_to_coeffs(&mut coeffs);
194+
prover_state.add_extension_scalars(&coeffs);
194195

195196
prover_state.pow_grinding(self.final_query_pow_bits);
196197

crates/whir/src/verify.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ where
161161

162162
// In the final round we receive the full polynomial instead of a commitment.
163163
let n_final_coeffs = 1 << self.n_vars_of_final_polynomial();
164-
let final_evaluations = verifier_state.next_extension_scalars_vec(n_final_coeffs)?;
164+
let final_coefficients = verifier_state.next_extension_scalars_vec(n_final_coeffs)?;
165165

166166
// Verify in-domain challenges on the previous commitment.
167167
let stir_constraints = self.verify_stir_challenges(
@@ -175,7 +175,7 @@ where
175175
// Verify stir constraints directly on final polynomial
176176
stir_constraints
177177
.iter()
178-
.all(|c| verify_constraint(c, &final_evaluations))
178+
.all(|c| verify_constraint_coeffs(c, &final_coefficients))
179179
.then_some(())
180180
.ok_or(ProofError::InvalidProof)
181181
.unwrap();
@@ -194,8 +194,10 @@ where
194194

195195
let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.clone());
196196

197-
// Check the final sumcheck evaluation
198-
let final_value = final_evaluations.evaluate(&final_sumcheck_randomness);
197+
// Check the final sumcheck evaluation (coefficient form, reversed point)
198+
let mut reversed_point = final_sumcheck_randomness.0.clone();
199+
reversed_point.reverse();
200+
let final_value = eval_multilinear_coeffs(&final_coefficients, &reversed_point);
199201
if claimed_sum != evaluation_of_weights * final_value {
200202
panic!();
201203
}
@@ -417,12 +419,19 @@ where
417419
}
418420
}
419421

420-
fn verify_constraint<EF: Field>(constraint: &SparseStatement<EF>, poly: &[EF]) -> bool {
421-
// poly.evaluate(&constraint.point) == constraint.value
422-
constraint
423-
.values
424-
.iter()
425-
.all(|e| poly.evaluate_sparse(e.selector, &constraint.point) == e.value)
422+
fn verify_constraint_coeffs<EF: Field>(constraint: &SparseStatement<EF>, coeffs: &[EF]) -> bool {
423+
assert_eq!(constraint.selector_num_variables(), 0);
424+
let alpha = constraint.point[0];
425+
// Verify the point is expand_from_univariate(alpha, n): [alpha, alpha^2, alpha^4, ...]
426+
assert!(
427+
constraint
428+
.point
429+
.iter()
430+
.zip(constraint.point.iter().skip(1))
431+
.all(|(&a, &b)| a * a == b)
432+
);
433+
let univariate_eval = coeffs.iter().rfold(EF::ZERO, |acc, &c| acc * alpha + c);
434+
constraint.values.iter().all(|e| univariate_eval == e.value)
426435
}
427436

428437
/// The full vector of folding randomness values, in reverse round order.

0 commit comments

Comments
 (0)