1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use crate::evm_circuit::util::{
    constraint_builder::EVMConstraintBuilder, math_gadget::*, CachedRegion,
};
use eth_types::Field;
use halo2_proofs::plonk::{Error, Expression};

/// Returns `1` when `lhs == rhs`, and returns `0` otherwise.
///
/// The equality check is performed by feeding the difference `lhs - rhs`
/// into an [`IsZeroGadget`]: the difference is zero exactly when the two
/// inputs are equal as field elements (equality is therefore modulo the
/// field characteristic, not over the integers).
#[derive(Clone, Debug)]
pub struct IsEqualGadget<F> {
    // Zero-check over `lhs - rhs`; its output expression is the equality flag.
    is_zero: IsZeroGadget<F>,
}

impl<F: Field> IsEqualGadget<F> {
    /// Builds the gadget by asking the constraint builder for an
    /// [`IsZeroGadget`] over the difference of the two operands.
    ///
    /// `lhs` and `rhs` are equal exactly when `lhs - rhs` is zero, so the
    /// inner zero-check's output doubles as the equality flag.
    pub(crate) fn construct(
        cb: &mut EVMConstraintBuilder<F>,
        lhs: Expression<F>,
        rhs: Expression<F>,
    ) -> Self {
        let difference = lhs - rhs;
        Self {
            is_zero: cb.is_zero(difference),
        }
    }

    /// Expression evaluating to `1` when the operands are equal and `0`
    /// otherwise.
    pub(crate) fn expr(&self) -> Expression<F> {
        let Self { is_zero } = self;
        is_zero.expr()
    }

    /// Assigns the witness for the comparison of `lhs` and `rhs` at `offset`.
    ///
    /// The difference `lhs - rhs` is handed to the inner zero-check gadget,
    /// and whatever that gadget's `assign` returns is forwarded unchanged
    /// (presumably the assigned `1`/`0` equality flag — confirm against
    /// `IsZeroGadget::assign`).
    ///
    /// # Errors
    ///
    /// Propagates any [`Error`] raised while assigning the inner gadget.
    pub(crate) fn assign(
        &self,
        region: &mut CachedRegion<'_, '_, F>,
        offset: usize,
        lhs: F,
        rhs: F,
    ) -> Result<F, Error> {
        let difference = lhs - rhs;
        self.is_zero.assign(region, offset, difference)
    }
}

#[cfg(test)]
mod tests {
    use super::{test_util::*, *};
    use crate::evm_circuit::util::{
        constraint_builder::ConstrainBuilderCommon, CachedRegion, Cell,
    };
    use eth_types::*;
    use halo2_proofs::{halo2curves::bn256::Fr, plonk::Error};

    /// IsEqualGadgetTestContainer: require(a == b)
    ///
    /// The container constrains the gadget's output to be `1`. The circuit is
    /// satisfiable only when the two witnesses are equal. A `true`
    /// expectation therefore exercises the "equal" path, and a `false`
    /// expectation checks that unequal inputs make the constraint fail.
    #[derive(Clone)]
    struct IsEqualGadgetTestContainer<F> {
        eq_gadget: IsEqualGadget<F>,
        a: Cell<F>,
        b: Cell<F>,
    }

    impl<F: Field> MathGadgetContainer<F> for IsEqualGadgetTestContainer<F> {
        fn configure_gadget_container(cb: &mut EVMConstraintBuilder<F>) -> Self {
            let a = cb.query_cell();
            let b = cb.query_cell();
            let eq_gadget = IsEqualGadget::<F>::construct(cb, a.expr(), b.expr());
            cb.require_equal("Inputs must equal (a==b)", eq_gadget.expr(), 1.expr());
            IsEqualGadgetTestContainer { eq_gadget, a, b }
        }

        fn assign_gadget_container(
            &self,
            witnesses: &[Word],
            region: &mut CachedRegion<'_, '_, F>,
        ) -> Result<(), Error> {
            // Every test input below is chosen to fit in the scalar field, so a
            // failed conversion indicates a broken test case rather than a
            // gadget bug.
            let a = witnesses[0]
                .to_scalar()
                .expect("witness `a` must fit in the scalar field");
            let b = witnesses[1]
                .to_scalar()
                .expect("witness `b` must fit in the scalar field");
            let offset = 0;

            self.a.assign(region, offset, Value::known(a))?;
            self.b.assign(region, offset, Value::known(b))?;
            self.eq_gadget.assign(region, offset, a, b)?;

            Ok(())
        }
    }

    #[test]
    fn test_isequal_0() {
        try_test!(
            IsEqualGadgetTestContainer<Fr>,
            [Word::from(0), Word::from(0)],
            true,
        );
    }

    #[test]
    fn test_isequal_1() {
        try_test!(
            IsEqualGadgetTestContainer<Fr>,
            [Word::from(1), Word::from(1)],
            true,
        );
    }

    #[test]
    fn test_isequal_1000() {
        try_test!(
            IsEqualGadgetTestContainer<Fr>,
            [Word::from(1000), Word::from(1000)],
            true,
        );
    }

    #[test]
    fn test_isequal_u64_max() {
        // Large-magnitude operands must still compare equal.
        try_test!(
            IsEqualGadgetTestContainer<Fr>,
            [Word::from(u64::MAX), Word::from(u64::MAX)],
            true,
        );
    }

    #[test]
    fn test_isequal_1_0() {
        try_test!(
            IsEqualGadgetTestContainer<Fr>,
            [Word::from(1), Word::from(0)],
            false,
        );
    }

    #[test]
    fn test_isequal_0_1() {
        try_test!(
            IsEqualGadgetTestContainer<Fr>,
            [Word::from(0), Word::from(1)],
            false,
        );
    }

    #[test]
    fn test_isequal_u64_max_off_by_one() {
        // Large operands that differ only in the lowest bit must not compare
        // equal.
        try_test!(
            IsEqualGadgetTestContainer<Fr>,
            [Word::from(u64::MAX), Word::from(u64::MAX - 1)],
            false,
        );
    }
}