// zk_kit_smt/utils.rs

/// Converts a hexadecimal string to a binary string.
///
/// The first hex digit is written with the minimal number of bits, and every
/// following digit contributes exactly four bits, so `"1A"` becomes `"11010"`.
/// A leading `0` digit is kept as a single `"0"` (e.g. `"0A"` becomes `"01010"`).
///
/// # Arguments
///
/// * `n` - The hexadecimal string to convert (digits only, no `0x` prefix).
///
/// # Returns
///
/// The binary representation of the hexadecimal string, or an empty string
/// when `n` is empty.
///
/// # Panics
///
/// Panics if `n` contains a character that is not a hexadecimal digit.
pub fn hex_to_bin(n: &str) -> String {
    let mut digits = n.chars().map(|c| {
        c.to_digit(16)
            .unwrap_or_else(|| panic!("invalid hexadecimal digit {:?} in {:?}", c, n))
    });

    let first = match digits.next() {
        Some(digit) => digit,
        // An empty input has no bits to emit.
        None => return String::new(),
    };

    // Every hex digit yields at most four binary digits.
    let mut bin = String::with_capacity(n.len() * 4);
    // The first digit is not zero-padded so the result carries no redundant leading zeros.
    bin.push_str(&format!("{:b}", first));
    for digit in digits {
        bin.push_str(&format!("{:04b}", digit));
    }
    bin
}

25/// Converts a hexadecimal key to a path represented as a vector of usize.
26///
27/// For each key, it is possible to obtain an array of 256 padded bits.
28///
29/// # Arguments
30///
31/// * `key` - The hexadecimal key to convert.
32///
33/// # Returns
34///
35/// The path represented as a vector of usize.
36pub fn key_to_path(key: &str) -> Vec<usize> {
37    let bits = if let Ok(num) = u128::from_str_radix(key, 16) {
38        format!("{:b}", num)
39    } else {
40        hex_to_bin(key)
41    };
42
43    let padded_bits = format!("{:0>256}", bits).chars().rev().collect::<String>();
44    let bits_array = padded_bits
45        .chars()
46        .map(|c| c.to_digit(10).unwrap() as usize)
47        .collect();
48
49    bits_array
50}
51
/// Returns the index of the last non-zero element in the array.
///
/// Each element is interpreted as a hexadecimal number of arbitrary width, so
/// values wider than 128 bits (e.g. 256-bit hashes) are recognized as non-zero.
/// Elements that are empty or not valid hexadecimal are treated as zero.
///
/// # Arguments
///
/// * `array` - The array of hexadecimal strings.
///
/// # Returns
///
/// The index of the last non-zero element in the array, or -1 if no non-zero element is found.
pub fn get_index_of_last_non_zero_element(array: Vec<&str>) -> isize {
    array
        .iter()
        .rposition(|&item| is_non_zero_hex(item))
        // A `Vec` never holds more than `isize::MAX` elements, so the cast is lossless.
        .map_or(-1, |index| index as isize)
}

/// Returns `true` if `item` is a valid hexadecimal number whose value is not zero.
///
/// Accepts the same syntax as `u128::from_str_radix(item, 16)` (an optional
/// single leading `+`, then at least one hex digit) but has no width limit.
fn is_non_zero_hex(item: &str) -> bool {
    let digits = item.strip_prefix('+').unwrap_or(item);
    !digits.is_empty()
        && digits.bytes().all(|b| b.is_ascii_hexdigit())
        && digits.bytes().any(|b| b != b'0')
}

/// Returns the longest common prefix shared by two slices.
///
/// Elements are compared pairwise from the start; collection stops at the
/// first mismatch or when the shorter slice runs out.
///
/// # Arguments
///
/// * `array1` - The first slice.
/// * `array2` - The second slice.
///
/// # Returns
///
/// A new vector holding clones of the leading elements both slices agree on.
pub fn get_first_common_elements<T: PartialEq + Clone>(array1: &[T], array2: &[T]) -> Vec<T> {
    let prefix_len = array1
        .iter()
        .zip(array2)
        .take_while(|(left, right)| left == right)
        .count();

    array1[..prefix_len].to_vec()
}

/// Checks if a string is a valid hexadecimal string.
///
/// A valid string consists of one or more ASCII hex digits (`0-9`, `a-f`,
/// `A-F`) with no prefix such as `0x`. The empty string is rejected, since it
/// does not encode any value.
///
/// # Arguments
///
/// * `s` - The string to check.
///
/// # Returns
///
/// `true` if the string is a valid hexadecimal string, `false` otherwise.
pub fn is_hexadecimal(s: &str) -> bool {
    !s.is_empty() && s.bytes().all(|b| b.is_ascii_hexdigit())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hex_to_bin() {
        assert_eq!(hex_to_bin("A"), "1010");
        assert_eq!(hex_to_bin("F"), "1111");
        assert_eq!(hex_to_bin("1A"), "11010");
        assert_eq!(hex_to_bin("FF"), "11111111");
        assert_eq!(hex_to_bin("12"), "10010");
    }

    #[test]
    fn test_hex_to_bin_lowercase_and_leading_zero() {
        assert_eq!(hex_to_bin("ab"), "10101011");
        // The first digit is not padded, so a leading `0` stays a single bit.
        assert_eq!(hex_to_bin("0"), "0");
        assert_eq!(hex_to_bin("0A"), "01010");
    }

    #[test]
    fn test_key_to_path() {
        let path = key_to_path("17");
        assert_eq!(path.len(), 256);
        assert_eq!(&path[0..5], vec![1, 1, 1, 0, 1]);
        assert!(path[5..].iter().all(|&bit| bit == 0));

        assert_eq!(key_to_path("0"), vec![0; 256]);
    }

    #[test]
    fn test_key_to_path_wider_than_u128() {
        // 33 hex digits overflow `u128`, so this exercises the `hex_to_bin` fallback.
        let key = format!("1{}", "0".repeat(32));
        let path = key_to_path(&key);
        assert_eq!(path.len(), 256);
        assert_eq!(path[128], 1);
        assert_eq!(path.iter().sum::<usize>(), 1);
    }

    #[test]
    fn test_get_index_of_last_non_zero_element() {
        assert_eq!(get_index_of_last_non_zero_element(vec![]), -1);
        assert_eq!(get_index_of_last_non_zero_element(vec!["0", "0", "0"]), -1);

        assert_eq!(get_index_of_last_non_zero_element(vec!["0", "0", "1"]), 2);
        assert_eq!(get_index_of_last_non_zero_element(vec!["0", "1", "0"]), 1);
        assert_eq!(get_index_of_last_non_zero_element(vec!["1", "0", "0"]), 0);

        assert_eq!(
            get_index_of_last_non_zero_element(vec!["0", "1", "0", "1", "0"]),
            3
        );
        assert_eq!(
            get_index_of_last_non_zero_element(vec!["1", "0", "1", "0", "0"]),
            2
        );
        assert_eq!(
            get_index_of_last_non_zero_element(vec!["0", "0", "0", "1", "1"]),
            4
        );
        assert_eq!(
            get_index_of_last_non_zero_element(vec![
                "0", "17", "3", "0", "3", "0", "3", "2", "0", "0"
            ]),
            7
        );
    }

    #[test]
    fn test_get_index_of_last_non_zero_element_zero_and_invalid_entries() {
        // A long run of zeros is still zero.
        let long_zero = "0".repeat(40);
        assert_eq!(
            get_index_of_last_non_zero_element(vec!["1", long_zero.as_str()]),
            0
        );
        // Empty and non-hex entries count as zero.
        assert_eq!(get_index_of_last_non_zero_element(vec!["zz", "1", ""]), 1);
    }

    #[test]
    fn test_get_first_common_elements() {
        assert_eq!(get_first_common_elements::<u32>(&[], &[]), vec![]);

        assert_eq!(
            get_first_common_elements(&[1, 2, 3], &[1, 2, 3, 4, 5]),
            vec![1, 2, 3]
        );
        assert_eq!(
            get_first_common_elements(&[1, 2, 3, 4, 5], &[1, 2, 3]),
            vec![1, 2, 3]
        );

        assert_eq!(
            get_first_common_elements(&[1, 2, 3], &[1, 2, 4]),
            vec![1, 2]
        );
        assert_eq!(get_first_common_elements(&[1, 2, 3], &[4, 5, 6]), vec![]);

        assert_eq!(
            get_first_common_elements(&["a", "b"], &["a", "c"]),
            vec!["a"]
        );
    }

    #[test]
    fn test_is_hexadecimal() {
        assert!(is_hexadecimal("be12"));
        assert!(is_hexadecimal("ABCDEF"));
        assert!(is_hexadecimal("1234567890abcdef"));

        assert!(!is_hexadecimal("gbe12"));
        assert!(!is_hexadecimal("123XYZ"));
        assert!(!is_hexadecimal("abcdefg"));
        // A `0x` prefix is not part of the accepted syntax.
        assert!(!is_hexadecimal("0x12"));
    }
}