Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/air/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ where
columns_up_down_group_packed,
air,
&extra_data,
Some((zerocheck_challenges, None)),
Some(zerocheck_challenges),
prover_state,
virtual_column_statement
.as_ref()
Expand Down
3 changes: 3 additions & 0 deletions crates/backend/sumcheck/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#![cfg_attr(not(test), warn(unused_crate_dependencies))]

mod split_eq;
pub use split_eq::*;

mod prove;
pub use prove::*;

Expand Down
58 changes: 21 additions & 37 deletions crates/backend/sumcheck/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn sumcheck_prove<'a, EF, SC, M: Into<MleGroup<'a, EF>>>(
multilinears_f: M,
computation: &SC,
extra_data: &SC::ExtraData,
eq_factor: Option<(Vec<EF>, Option<MleOwned<EF>>)>, // (a, b, c ...), eq_poly(b, c, ...)
eq_factor: Option<Vec<EF>>,
prover_state: &mut impl FSProver<EF>,
sum: EF,
store_intermediate_foldings: bool,
Expand Down Expand Up @@ -39,7 +39,7 @@ pub fn sumcheck_fold_and_prove<'a, EF, SC, M: Into<MleGroup<'a, EF>>>(
prev_folding_factor: Option<EF>,
computation: &SC,
extra_data: &SC::ExtraData,
eq_factor: Option<(Vec<EF>, Option<MleOwned<EF>>)>, // (a, b, c ...), eq_poly(b, c, ...)
eq_factor: Option<Vec<EF>>,
prover_state: &mut impl FSProver<EF>,
sum: EF,
store_intermediate_foldings: bool,
Expand Down Expand Up @@ -88,7 +88,7 @@ pub fn sumcheck_prove_many_rounds<'a, EF, SC, M: Into<MleGroup<'a, EF>>>(
mut prev_folding_factor: Option<EF>,
computation: &SC,
extra_data: &SC::ExtraData,
mut eq_factor: Option<(Vec<EF>, Option<MleOwned<EF>>)>, // (a, b, c ...), eq_poly(b, c, ...)
mut eq_factor: Option<Vec<EF>>,
prover_state: &mut impl FSProver<EF>,
mut sum: EF,
mut missing_mul_factors: Option<EF>,
Expand All @@ -102,49 +102,31 @@ where
SC::ExtraData: AlphaPowers<EF>,
{
let mut multilinears: MleGroup<'a, EF> = multilinears_f.into();

let mut eq_factor: Option<(Vec<EF>, MleOwned<EF>)> = eq_factor.take().map(|(eq_point, eq_mle)| {
let eq_mle = eq_mle.unwrap_or_else(|| {
let eval_eq_ext = eval_eq(&eq_point[1..]);
if multilinears.by_ref().is_packed() {
MleOwned::ExtensionPacked(pack_extension(&eval_eq_ext))
} else {
MleOwned::Extension(eval_eq_ext)
}
});
(eq_point, eq_mle)
});

let mut n_vars = multilinears.by_ref().n_vars();
if prev_folding_factor.is_some() {
n_vars -= 1;
}
if let Some((eq_point, eq_mle)) = &eq_factor {

let mut eq_factor_and_split: Option<(Vec<EF>, SplitEq<EF>)> = eq_factor.take().map(|eq_point| {
assert_eq!(eq_point.len(), n_vars);
assert_eq!(eq_mle.by_ref().n_vars(), eq_point.len() - 1);
if eq_mle.by_ref().is_packed() && !multilinears.is_packed() {
assert!(eq_point.len() < packing_log_width::<EF>());
multilinears = multilinears.by_ref().unpack().as_owned_or_clone().into();
}
}
let split_eq = SplitEq::new(&eq_point[1..]);
(eq_point, split_eq)
});

let mut challenges = Vec::new();
for _ in 0..n_rounds {
// If Packing is enabled, and there are too little variables, we unpack everything:
if multilinears.by_ref().is_packed() && n_vars <= 1 + packing_log_width::<EF>() {
// unpack
multilinears = multilinears.by_ref().unpack().as_owned_or_clone().into();

if let Some((_, eq_mle)) = &mut eq_factor {
*eq_mle = eq_mle.by_ref().unpack().as_owned_or_clone();
}
// SplitEq handles unpacking transparently via get_unpacked
}

let ps = compute_and_send_polynomial(
&mut multilinears,
prev_folding_factor,
computation,
&eq_factor,
&eq_factor_and_split,
extra_data,
prover_state,
sum,
Expand All @@ -157,7 +139,7 @@ where
prev_folding_factor = on_challenge_received(
&mut multilinears,
&mut n_vars,
&mut eq_factor,
&mut eq_factor_and_split,
&mut sum,
&mut missing_mul_factors,
challenge,
Expand All @@ -178,7 +160,7 @@ fn compute_and_send_polynomial<'a, EF, SC>(
multilinears: &mut MleGroup<'a, EF>,
prev_folding_factor: Option<EF>,
computation: &SC,
eq_factor: &Option<(Vec<EF>, MleOwned<EF>)>, // (a, b, c ...), eq_poly(b, c, ...)
eq_factor_and_split: &Option<(Vec<EF>, SplitEq<EF>)>,
extra_data: &SC::ExtraData,
prover_state: &mut impl FSProver<EF>,
sum: EF,
Expand All @@ -196,8 +178,10 @@ where
let computation_degree = computation.degree();

let sc_params = SumcheckComputeParams {
eq_mle: eq_factor.as_ref().map(|(_, eq_mle)| eq_mle),
first_eq_factor: eq_factor.as_ref().map(|(first_eq_factor, _)| first_eq_factor[0]),
split_eq: eq_factor_and_split.as_ref().map(|(_, split_eq)| split_eq),
first_eq_factor: eq_factor_and_split
.as_ref()
.map(|(first_eq_factor, _)| first_eq_factor[0]),
computation,
extra_data,
missing_mul_factor,
Expand All @@ -217,7 +201,7 @@ where
None => sumcheck_compute(&multilinears.by_ref(), sc_params, computation_degree),
});

let p_at_1 = if let Some((eq_factor, _)) = eq_factor {
let p_at_1 = if let Some((eq_factor, _)) = eq_factor_and_split {
(sum - (EF::ONE - eq_factor[0]) * p_evals[0]) / eq_factor[0]
} else {
sum - p_evals[0]
Expand All @@ -232,7 +216,7 @@ where
.collect::<Vec<_>>(),
)
.unwrap();
let eq_alpha = eq_factor.as_ref().map(|(p, _)| p[0]);
let eq_alpha = eq_factor_and_split.as_ref().map(|(p, _)| p[0]);
prover_state.add_sumcheck_polynomial(&poly.coeffs, eq_alpha);
poly
}
Expand All @@ -241,7 +225,7 @@ where
fn on_challenge_received<'a, EF: ExtensionField<PF<EF>>>(
multilinears: &mut MleGroup<'a, EF>,
n_vars: &mut usize,
eq_factor: &mut Option<(Vec<EF>, MleOwned<EF>)>, // (a, b, c ...), eq_poly(b, c, ...)
eq_factor: &mut Option<(Vec<EF>, SplitEq<EF>)>,
sum: &mut EF,
missing_mul_factor: &mut Option<EF>,
challenge: EF,
Expand All @@ -253,7 +237,7 @@ fn on_challenge_received<'a, EF: ExtensionField<PF<EF>>>(
*sum = p.evaluate(challenge);
*n_vars -= 1;

if let Some((eq_factor, eq_mle)) = eq_factor {
if let Some((eq_factor, split_eq)) = eq_factor {
// Multiply sum by eq(α_i, r_i) since the polynomial doesn't include the eq linear factor
let eq_eval = (EF::ONE - eq_factor[0]) * (EF::ONE - challenge) + eq_factor[0] * challenge;
*sum *= eq_eval;
Expand All @@ -262,7 +246,7 @@ fn on_challenge_received<'a, EF: ExtensionField<PF<EF>>>(
eq_eval * missing_mul_factor.unwrap_or(EF::ONE) / (EF::ONE - eq_factor.get(1).copied().unwrap_or_default()),
);
eq_factor.remove(0);
eq_mle.truncate(eq_mle.by_ref().packed_len() / 2);
split_eq.truncate_half();
}

if store_intermediate_foldings {
Expand Down
Loading
Loading