use std::{collections::HashMap, iter};
use crate::plonk::Error;
use group::Curve;
use halo2_middleware::ff::Field;
use halo2_middleware::zal::{impls::PlonkEngine, traits::MsmAccel};
use rand_chacha::ChaCha20Rng;
use rand_core::{RngCore, SeedableRng};
use super::Argument;
use crate::{
arithmetic::{eval_polynomial, parallelize, CurveAffine},
multicore::current_num_threads,
plonk::ChallengeX,
poly::{
commitment::{Blind, ParamsProver},
Coeff, EvaluationDomain, ExtendedLagrangeCoeff, Polynomial, ProverQuery,
},
transcript::{EncodedChallenge, TranscriptWrite},
};
/// State produced by `Argument::commit`: the randomly sampled polynomial that
/// has been committed to and written to the transcript.
pub(in crate::plonk) struct Committed<C>
where
    C: CurveAffine,
{
    /// Polynomial whose `2^k` coefficients were sampled at random in `commit`.
    random_poly: Polynomial<C::Scalar, Coeff>,
}
/// State produced by `Committed::construct`: the committed pieces of the
/// quotient polynomial plus the earlier committed random polynomial.
pub(in crate::plonk) struct Constructed<C>
where
    C: CurveAffine,
{
    /// Coefficient-form pieces of the quotient, each `params.n()` long, in the
    /// order their commitments were written to the transcript.
    h_pieces: Vec<Polynomial<C::Scalar, Coeff>>,
    /// The random polynomial carried forward from the commit step.
    committed: Committed<C>,
}
/// State produced by `Constructed::evaluate`: the quotient pieces recombined
/// into a single polynomial, ready to be opened.
pub(in crate::plonk) struct Evaluated<C>
where
    C: CurveAffine,
{
    /// The pieces folded together as `sum_i h_pieces[i] * xn^i`.
    h_poly: Polynomial<C::Scalar, Coeff>,
    /// The random polynomial carried forward from the commit step.
    committed: Committed<C>,
}
impl<C: CurveAffine> Argument<C> {
    /// Samples a polynomial with `n = 2^k` random coefficients (`k` taken from
    /// `domain`), commits to it under a freshly sampled blinding factor, and
    /// writes the resulting commitment point to `transcript`.
    ///
    /// Coefficient sampling is spread across threads with [`parallelize`]. To
    /// keep the result a deterministic function of `rng` regardless of thread
    /// scheduling, one `ChaCha20Rng` is seeded from `rng` per chunk start
    /// offset, in ascending offset order, before any coefficient is drawn. The
    /// blinding factor is drawn from `rng` only after all of those seeds.
    ///
    /// # Errors
    ///
    /// Returns an error if writing the commitment to `transcript` fails.
    ///
    /// # Panics
    ///
    /// Panics if [`parallelize`] invokes the closure with a chunk offset for
    /// which no seed was precomputed, i.e. if the offset layout mirrored below
    /// falls out of sync with `parallelize`'s chunking.
    pub(in crate::plonk) fn commit<
        P: ParamsProver<C>,
        E: EncodedChallenge<C>,
        R: RngCore,
        T: TranscriptWrite<C, E>,
    >(
        engine: &impl MsmAccel<C>,
        params: &P,
        domain: &EvaluationDomain<C::Scalar>,
        mut rng: R,
        transcript: &mut T,
    ) -> Result<Committed<C>, Error> {
        let n = 1usize << domain.k() as usize;
        let mut rand_vec = vec![C::Scalar::ZERO; n];

        // Precompute one seeded RNG per chunk, keyed by the chunk's start offset.
        // The offsets encode a layout where the first `n % num_threads` chunks
        // hold `chunk_size + 1` elements and the remaining ones hold
        // `chunk_size`; when `chunk_size == 0` only the first group exists.
        // NOTE(review): this layout must stay in sync with
        // `crate::arithmetic::parallelize`.
        let num_threads = current_num_threads();
        let chunk_size = n / num_threads;
        let thread_seeds = (0..)
            .step_by(chunk_size + 1)
            .take(n % num_threads)
            .chain(
                (chunk_size != 0)
                    .then(|| ((n % num_threads) * (chunk_size + 1)..).step_by(chunk_size))
                    .into_iter()
                    .flatten(),
            )
            .take(num_threads)
            .zip(iter::repeat_with(|| {
                let mut seed = [0u8; 32];
                rng.fill_bytes(&mut seed);
                ChaCha20Rng::from_seed(seed)
            }))
            .collect::<HashMap<_, _>>();

        parallelize(&mut rand_vec, |chunk, offset| {
            // Clone the chunk's seeded RNG instead of moving it out of the
            // shared map; each chunk starts from its own seed.
            let mut rng = thread_seeds
                .get(&offset)
                .expect("parallelize chunk offset has no precomputed seed")
                .clone();
            chunk
                .iter_mut()
                .for_each(|v| *v = C::Scalar::random(&mut rng));
        });

        let random_poly: Polynomial<C::Scalar, Coeff> = domain.coeff_from_vec(rand_vec);

        // Drawn after all per-chunk seeds, so seeded runs stay reproducible.
        let random_blind = Blind(C::Scalar::random(rng));

        let c = params
            .commit(engine, &random_poly, random_blind)
            .to_affine();
        transcript.write_point(c)?;

        Ok(Committed { random_poly })
    }
}
impl<C: CurveAffine> Committed<C> {
pub(in crate::plonk) fn construct<
P: ParamsProver<C>,
E: EncodedChallenge<C>,
R: RngCore,
T: TranscriptWrite<C, E>,
M: MsmAccel<C>,
>(
self,
engine: &PlonkEngine<C, M>,
params: &P,
domain: &EvaluationDomain<C::Scalar>,
h_poly: Polynomial<C::Scalar, ExtendedLagrangeCoeff>,
mut rng: R,
transcript: &mut T,
) -> Result<Constructed<C>, Error> {
let h_poly = domain.divide_by_vanishing_poly(h_poly);
let mut h_poly = domain.extended_to_coeff(h_poly);
h_poly.truncate(((1u64 << domain.k()) as usize) * domain.get_quotient_poly_degree());
let h_pieces = h_poly
.chunks_exact(params.n() as usize)
.map(|v| domain.coeff_from_vec(v.to_vec()))
.collect::<Vec<_>>();
drop(h_poly);
let h_blinds: Vec<_> = h_pieces
.iter()
.map(|_| Blind(C::Scalar::random(&mut rng)))
.collect();
let h_commitments = {
let h_commitments_projective: Vec<_> = h_pieces
.iter()
.zip(h_blinds.iter())
.map(|(h_piece, blind)| params.commit(&engine.msm_backend, h_piece, *blind))
.collect();
let mut h_commitments = vec![C::identity(); h_commitments_projective.len()];
C::Curve::batch_normalize(&h_commitments_projective, &mut h_commitments);
h_commitments
};
for c in h_commitments {
transcript.write_point(c)?;
}
Ok(Constructed {
h_pieces,
committed: self,
})
}
}
impl<C: CurveAffine> Constructed<C> {
pub(in crate::plonk) fn evaluate<E: EncodedChallenge<C>, T: TranscriptWrite<C, E>>(
self,
x: ChallengeX<C>,
xn: C::Scalar,
domain: &EvaluationDomain<C::Scalar>,
transcript: &mut T,
) -> Result<Evaluated<C>, Error> {
let h_poly = self
.h_pieces
.iter()
.rev()
.fold(domain.empty_coeff(), |acc, eval| acc * xn + eval);
let random_eval = eval_polynomial(&self.committed.random_poly, *x);
transcript.write_scalar(random_eval)?;
Ok(Evaluated {
h_poly,
committed: self.committed,
})
}
}
impl<C: CurveAffine> Evaluated<C> {
    /// Yields this argument's two opening queries, both at the point `x`.
    ///
    /// The queries are, in order, the recombined quotient polynomial `h_poly`
    /// and then the committed random polynomial.
    pub(in crate::plonk) fn open(
        &self,
        x: ChallengeX<C>,
    ) -> impl Iterator<Item = ProverQuery<'_, C>> + Clone {
        let point = *x;
        let h_query = ProverQuery {
            point,
            poly: &self.h_poly,
        };
        let random_query = ProverQuery {
            point,
            poly: &self.committed.random_poly,
        };
        iter::once(h_query).chain(iter::once(random_query))
    }
}