use core::cmp::max;
use core::ops::{Add, Mul, Neg, Sub};
use ff::Field;
use std::iter::{Product, Sum};
/// A leaf variable that can appear inside an [`Expression`].
///
/// Implementors must be `Copy` (variables are copied out of the tree by `Expression::evaluate`),
/// comparable, printable, and able to report the degree and complexity weight they contribute
/// to an expression.
pub trait Variable: Clone + Copy + std::fmt::Debug + std::fmt::Display + Eq + PartialEq {
    /// Degree contributed by this variable; `Expression::degree` returns it unchanged for a
    /// `Var` leaf.
    fn degree(&self) -> usize;
    /// Weight of this variable as used by `Expression::complexity`. Defaults to `0`.
    fn complexity(&self) -> usize {
        0
    }
    /// Writes this variable's textual identifier to `writer`; `Expression::write_identifier`
    /// delegates to it for every `Var` leaf.
    fn write_identifier<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()>;
}
/// An expression tree over field constants of type `F` and variables of type `V`.
///
/// Nodes are stored exactly as built: the operator impls for this type wrap their operands
/// without any simplification or constant folding, so structurally different trees that denote
/// the same value compare unequal under the derived `PartialEq`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Expression<F, V: Variable> {
    /// A field constant.
    Constant(F),
    /// A variable leaf.
    Var(V),
    /// The additive inverse of the inner expression.
    Negated(Box<Expression<F, V>>),
    /// The sum of the two operands (left, right).
    Sum(Box<Expression<F, V>>, Box<Expression<F, V>>),
    /// The product of the two operands (left, right).
    Product(Box<Expression<F, V>>, Box<Expression<F, V>>),
}
impl<F: Field, V: Variable> Expression<F, V> {
#[allow(clippy::too_many_arguments)]
pub fn evaluate<T>(
&self,
constant: &impl Fn(F) -> T,
var: &impl Fn(V) -> T,
negated: &impl Fn(T) -> T,
sum: &impl Fn(T, T) -> T,
product: &impl Fn(T, T) -> T,
) -> T {
match self {
Expression::Constant(scalar) => constant(*scalar),
Expression::Var(v) => var(*v),
Expression::Negated(a) => {
let a = a.evaluate(constant, var, negated, sum, product);
negated(a)
}
Expression::Sum(a, b) => {
let a = a.evaluate(constant, var, negated, sum, product);
let b = b.evaluate(constant, var, negated, sum, product);
sum(a, b)
}
Expression::Product(a, b) => {
let a = a.evaluate(constant, var, negated, sum, product);
let b = b.evaluate(constant, var, negated, sum, product);
product(a, b)
}
}
}
fn write_identifier<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
match self {
Expression::Constant(scalar) => write!(writer, "{scalar:?}"),
Expression::Var(v) => v.write_identifier(writer),
Expression::Negated(a) => {
writer.write_all(b"(-")?;
a.write_identifier(writer)?;
writer.write_all(b")")
}
Expression::Sum(a, b) => {
writer.write_all(b"(")?;
a.write_identifier(writer)?;
writer.write_all(b"+")?;
b.write_identifier(writer)?;
writer.write_all(b")")
}
Expression::Product(a, b) => {
writer.write_all(b"(")?;
a.write_identifier(writer)?;
writer.write_all(b"*")?;
b.write_identifier(writer)?;
writer.write_all(b")")
}
}
}
pub fn identifier(&self) -> String {
let mut cursor = std::io::Cursor::new(Vec::new());
self.write_identifier(&mut cursor).unwrap();
String::from_utf8(cursor.into_inner()).unwrap()
}
pub fn degree(&self) -> usize {
use Expression::*;
match self {
Constant(_) => 0,
Var(v) => v.degree(),
Negated(poly) => poly.degree(),
Sum(a, b) => max(a.degree(), b.degree()),
Product(a, b) => a.degree() + b.degree(),
}
}
pub fn complexity(&self) -> usize {
match self {
Expression::Constant(_) => 0,
Expression::Var(v) => v.complexity(),
Expression::Negated(poly) => poly.complexity() + 5,
Expression::Sum(a, b) => a.complexity() + b.complexity() + 15,
Expression::Product(a, b) => a.complexity() + b.complexity() + 30,
}
}
}
/// Unary minus: wraps the operand in a `Negated` node without simplifying it.
impl<F: Field, V: Variable> Neg for Expression<F, V> {
    type Output = Self;

    fn neg(self) -> Self {
        let operand = Box::new(self);
        Self::Negated(operand)
    }
}
/// `a + b` builds a `Sum(a, b)` node; operands are kept in order and not simplified.
impl<F: Field, V: Variable> Add for Expression<F, V> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        let (left, right) = (Box::new(self), Box::new(rhs));
        Self::Sum(left, right)
    }
}
/// `a - b` builds `Sum(a, Negated(b))`; there is no dedicated difference node.
impl<F: Field, V: Variable> Sub for Expression<F, V> {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        // Reuse the `Add` and `Neg` impls: the resulting tree is exactly `a + (-b)`.
        self + -rhs
    }
}
/// `a * b` builds a `Product(a, b)` node; operands are kept in order and not simplified.
impl<F: Field, V: Variable> Mul for Expression<F, V> {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self {
        let (left, right) = (Box::new(self), Box::new(rhs));
        Self::Product(left, right)
    }
}
/// `expr * scalar` builds `Product(expr, Constant(scalar))`, with the scalar on the right.
impl<F: Field, V: Variable> Mul<F> for Expression<F, V> {
    type Output = Self;

    fn mul(self, rhs: F) -> Self {
        // Lift the raw scalar into a constant leaf before pairing it with the expression.
        let scalar = Self::Constant(rhs);
        Self::Product(Box::new(self), Box::new(scalar))
    }
}
/// Adds up an iterator of expressions left-associatively: `[a, b, c]` becomes `(a + b) + c`.
///
/// An empty iterator yields the constant zero.
impl<F: Field, V: Variable> Sum<Self> for Expression<F, V> {
    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
        match iter.next() {
            // Empty sum: the additive identity.
            None => Self::Constant(F::ZERO),
            // Seed the fold with the first term so no spurious `0 +` node is introduced.
            Some(first) => iter.fold(first, |total, term| total + term),
        }
    }
}
/// Multiplies an iterator of expressions left-associatively: `[a, b, c]` becomes `(a * b) * c`.
///
/// An empty iterator yields the constant one.
impl<F: Field, V: Variable> Product<Self> for Expression<F, V> {
    fn product<I: Iterator<Item = Self>>(mut iter: I) -> Self {
        match iter.next() {
            // Empty product: the multiplicative identity.
            None => Self::Constant(F::ONE),
            // Seed the fold with the first factor so no spurious `1 *` node is introduced.
            Some(first) => iter.fold(first, |total, factor| total * factor),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::circuit::ExpressionMid;
    use halo2curves::bn256::Fr;

    /// The constants 1, 2 and 3, in that order; the shared input of both iterator tests.
    fn one_two_three() -> Vec<ExpressionMid<Fr>> {
        vec![
            ExpressionMid::Constant(1.into()),
            ExpressionMid::Constant(2.into()),
            ExpressionMid::Constant(3.into()),
        ]
    }

    /// `Iterator::sum` must nest left-associatively: `(1 + 2) + 3`.
    #[test]
    fn iter_sum() {
        let actual: ExpressionMid<Fr> = one_two_three().into_iter().sum();

        let one_plus_two: ExpressionMid<Fr> = ExpressionMid::Sum(
            Box::new(ExpressionMid::Constant(1.into())),
            Box::new(ExpressionMid::Constant(2.into())),
        );
        let expected: ExpressionMid<Fr> = ExpressionMid::Sum(
            Box::new(one_plus_two),
            Box::new(ExpressionMid::Constant(3.into())),
        );

        assert_eq!(actual, expected);
    }

    /// `Iterator::product` must nest left-associatively: `(1 * 2) * 3`.
    #[test]
    fn iter_product() {
        let actual: ExpressionMid<Fr> = one_two_three().into_iter().product();

        let one_times_two: ExpressionMid<Fr> = ExpressionMid::Product(
            Box::new(ExpressionMid::Constant(1.into())),
            Box::new(ExpressionMid::Constant(2.into())),
        );
        let expected: ExpressionMid<Fr> = ExpressionMid::Product(
            Box::new(one_times_two),
            Box::new(ExpressionMid::Constant(3.into())),
        );

        assert_eq!(actual, expected);
    }
}