Skip to content

Commit 057934e

Browse files
authored
Merge pull request #17 from Shnatsel/safe-simd-ycbcr
Safe AVX YCbCr
2 parents f4ab0f2 + f75f32b commit 057934e

File tree

1 file changed

+133
-54
lines changed

1 file changed

+133
-54
lines changed

src/avx2/ycbcr.rs

Lines changed: 133 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,40 @@ macro_rules! ycbcr_image_avx2 {
2020

2121
impl<'a> $name<'a> {
2222
#[target_feature(enable = "avx2")]
23-
unsafe fn fill_buffers_avx2(&self, y: u16, buffers: &mut [Vec<u8>; 4]) {
24-
unsafe fn load3(data: *const u8) -> __m256i {
23+
fn fill_buffers_avx2(&self, y: u16, buffers: &mut [Vec<u8>; 4]) {
24+
// TODO: this compiles to many separate scalar loads and could be optimized further.
25+
// But the gains are no more than 3% end to end and it doesn't seem to be worth the complexity.
26+
#[inline]
27+
#[target_feature(enable = "avx2")]
28+
fn load3(data: &[u8]) -> __m256i {
29+
_ = data[7 * $num_colors]; // dummy indexing operation up front to avoid bounds checks later
2530
_mm256_set_epi32(
26-
*data as i32,
27-
*data.offset(1 * $num_colors) as i32,
28-
*data.offset(2 * $num_colors) as i32,
29-
*data.offset(3 * $num_colors) as i32,
30-
*data.offset(4 * $num_colors) as i32,
31-
*data.offset(5 * $num_colors) as i32,
32-
*data.offset(6 * $num_colors) as i32,
33-
*data.offset(7 * $num_colors) as i32,
31+
data[0] as i32,
32+
data[1 * $num_colors] as i32,
33+
data[2 * $num_colors] as i32,
34+
data[3 * $num_colors] as i32,
35+
data[4 * $num_colors] as i32,
36+
data[5 * $num_colors] as i32,
37+
data[6 * $num_colors] as i32,
38+
data[7 * $num_colors] as i32,
3439
)
3540
}
3641

37-
let mut y_buffer = buffers[0].as_mut_ptr().add(buffers[0].len());
38-
buffers[0].set_len(buffers[0].len() + self.width() as usize);
39-
let mut cb_buffer = buffers[1].as_mut_ptr().add(buffers[1].len());
40-
buffers[1].set_len(buffers[1].len() + self.width() as usize);
41-
let mut cr_buffer = buffers[2].as_mut_ptr().add(buffers[2].len());
42-
buffers[2].set_len(buffers[2].len() + self.width() as usize);
42+
#[inline]
43+
#[target_feature(enable = "avx2")]
44+
fn avx_as_i32_array(data: __m256i) -> [i32; 8] {
45+
// Safety preconditions. Optimized away in release mode, no runtime cost.
46+
assert!(core::mem::size_of::<__m256i>() == core::mem::size_of::<[i32; 8]>());
47+
assert!(core::mem::align_of::<__m256i>() >= core::mem::align_of::<[i32; 8]>());
48+
// SAFETY: size and alignment preconditions checked above.
49+
// Both types are plain old data: no pointers, lifetimes, etc.
50+
unsafe { core::mem::transmute(data) }
51+
}
52+
53+
let [y_buffer, cb_buffer, cr_buffer, _] = buffers;
54+
y_buffer.reserve(self.width() as usize);
55+
cb_buffer.reserve(self.width() as usize);
56+
cr_buffer.reserve(self.width() as usize);
4357

4458
let ymulr = _mm256_set1_epi32(19595);
4559
let ymulg = _mm256_set1_epi32(38470);
@@ -53,17 +67,14 @@ macro_rules! ycbcr_image_avx2 {
5367
let crmulg = _mm256_set1_epi32(27439);
5468
let crmulb = _mm256_set1_epi32(5329);
5569

56-
let mut data = self
57-
.0
58-
.as_ptr()
59-
.offset((y as isize * self.1 as isize * $num_colors));
70+
let mut data = &self.0[(y as usize * self.1 as usize * $num_colors)..];
6071

6172
for _ in 0..self.width() / 8 {
62-
let r = load3(data.offset($o1));
63-
let g = load3(data.offset($o2));
64-
let b = load3(data.offset($o3));
73+
let r = load3(&data[$o1..]);
74+
let g = load3(&data[$o2..]);
75+
let b = load3(&data[$o3..]);
6576

66-
data = data.add($num_colors * 8);
77+
data = &data[($num_colors * 8)..];
6778

6879
let yr = _mm256_mullo_epi32(ymulr, r);
6980
let yg = _mm256_mullo_epi32(ymulg, g);
@@ -72,7 +83,10 @@ macro_rules! ycbcr_image_avx2 {
7283
let y = _mm256_add_epi32(_mm256_add_epi32(yr, yg), yb);
7384
let y = _mm256_add_epi32(y, _mm256_set1_epi32(0x7FFF));
7485
let y = _mm256_srli_epi32(y, 16);
75-
let y: [i32; 8] = core::mem::transmute(y);
86+
let y: [i32; 8] = avx_as_i32_array(y);
87+
let mut y: [u8; 8] = y.map(|x| x as u8);
88+
y.reverse();
89+
y_buffer.extend_from_slice(&y);
7690

7791
let cbr = _mm256_mullo_epi32(cbmulr, r);
7892
let cbg = _mm256_mullo_epi32(cbmulg, g);
@@ -82,7 +96,10 @@ macro_rules! ycbcr_image_avx2 {
8296
let cb = _mm256_add_epi32(cb, _mm256_set1_epi32(128 << 16));
8397
let cb = _mm256_add_epi32(cb, _mm256_set1_epi32(0x7FFF));
8498
let cb = _mm256_srli_epi32(cb, 16);
85-
let cb: [i32; 8] = core::mem::transmute(cb);
99+
let cb: [i32; 8] = avx_as_i32_array(cb);
100+
let mut cb: [u8; 8] = cb.map(|x| x as u8);
101+
cb.reverse();
102+
cb_buffer.extend_from_slice(&cb);
86103

87104
let crr = _mm256_mullo_epi32(crmulr, r);
88105
let crg = _mm256_mullo_epi32(crmulg, g);
@@ -92,38 +109,19 @@ macro_rules! ycbcr_image_avx2 {
92109
let cr = _mm256_add_epi32(cr, _mm256_set1_epi32(128 << 16));
93110
let cr = _mm256_add_epi32(cr, _mm256_set1_epi32(0x7FFF));
94111
let cr = _mm256_srli_epi32(cr, 16);
95-
let cr: [i32; 8] = core::mem::transmute(cr);
96-
97-
for y in y.iter().rev() {
98-
*y_buffer = *y as u8;
99-
y_buffer = y_buffer.offset(1);
100-
}
101-
102-
for cb in cb.iter().rev() {
103-
*cb_buffer = *cb as u8;
104-
cb_buffer = cb_buffer.offset(1);
105-
}
106-
107-
for cr in cr.iter().rev() {
108-
*cr_buffer = *cr as u8;
109-
cr_buffer = cr_buffer.offset(1);
110-
}
112+
let cr: [i32; 8] = avx_as_i32_array(cr);
113+
let mut cr: [u8; 8] = cr.map(|x| x as u8);
114+
cr.reverse();
115+
cr_buffer.extend_from_slice(&cr);
111116
}
112117

113118
for _ in 0..self.width() % 8 {
114-
let (y, cb, cr) =
115-
rgb_to_ycbcr(*data.offset($o1), *data.offset($o2), *data.offset($o3));
116-
117-
data = data.add($num_colors);
119+
let (y, cb, cr) = rgb_to_ycbcr(data[$o1], data[$o2], data[$o3]);
120+
data = &data[$num_colors..];
118121

119-
*y_buffer = y;
120-
y_buffer = y_buffer.offset(1);
121-
122-
*cb_buffer = cb;
123-
cb_buffer = cb_buffer.offset(1);
124-
125-
*cr_buffer = cr;
126-
cr_buffer = cr_buffer.offset(1);
122+
y_buffer.push(y);
123+
cb_buffer.push(cb);
124+
cr_buffer.push(cr);
127125
}
128126
}
129127
}
@@ -155,3 +153,84 @@ ycbcr_image_avx2!(RgbImageAVX2, 3, 0, 1, 2);
155153
ycbcr_image_avx2!(RgbaImageAVX2, 4, 0, 1, 2);
156154
ycbcr_image_avx2!(BgrImageAVX2, 3, 2, 1, 0);
157155
ycbcr_image_avx2!(BgraImageAVX2, 4, 2, 1, 0);
156+
157+
#[cfg(test)]
158+
mod tests {
159+
use super::*;
160+
use std::vec::Vec;
161+
162+
// A very basic linear congruential generator (LCG) to avoid external dependencies.
163+
pub struct SimpleRng {
164+
state: u64,
165+
}
166+
167+
impl SimpleRng {
168+
/// Create a new RNG with a given seed.
169+
pub fn new(seed: u64) -> Self {
170+
Self { state: seed }
171+
}
172+
173+
/// Generate the next random u64 value.
174+
pub fn next_u64(&mut self) -> u64 {
175+
// Constants from Numerical Recipes
176+
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
177+
self.state
178+
}
179+
180+
/// Generate a random byte in 0..=255
181+
pub fn next_byte(&mut self) -> u8 {
182+
(self.next_u64() & 0xFF) as u8
183+
}
184+
185+
/// Fill a Vec<u8> with random bytes of the given length.
186+
pub fn random_bytes(&mut self, len: usize) -> Vec<u8> {
187+
(0..len).map(|_| self.next_byte()).collect()
188+
}
189+
}
190+
191+
#[test]
192+
#[cfg(feature = "simd")]
193+
fn avx_matches_scalar_rgb() {
194+
// Do not run AVX2 test on machines without it
195+
if !std::is_x86_feature_detected!("avx2") {
196+
return;
197+
}
198+
let mut rng = SimpleRng::new(42);
199+
let width = 512 + 3; // power of two plus a bit to stress remainder handling
200+
let height = 1;
201+
let bpp = 3;
202+
203+
let input = rng.random_bytes(width * height * bpp); // power of two plus a bit to exercise remainder handling
204+
205+
let scalar_result: Vec<[u8; 3]> = input
206+
.chunks_exact(bpp)
207+
.map(|chunk| {
208+
let [r, g, b, ..] = chunk else { unreachable!() };
209+
let (y, cb, cr) = rgb_to_ycbcr(*r, *g, *b);
210+
[y, cb, cr]
211+
})
212+
.collect();
213+
214+
let mut buffers = [Vec::new(), Vec::new(), Vec::new(), Vec::new()];
215+
let avx_input = RgbImageAVX2(
216+
&input,
217+
width.try_into().unwrap(),
218+
height.try_into().unwrap(),
219+
);
220+
// SAFETY: we've checked above that AVX2 is present
221+
unsafe {
222+
avx_input.fill_buffers_avx2(0, &mut buffers);
223+
}
224+
225+
for i in 0..3 {
226+
assert_eq!(buffers[i].len(), input.len() / 3);
227+
}
228+
229+
for (i, pixel) in scalar_result.iter().copied().enumerate() {
230+
let avx_pixel: [u8; 3] = [buffers[0][i], buffers[1][i], buffers[2][i]];
231+
if pixel != avx_pixel {
232+
panic!("Mismatch at index {i}: scalar result is {pixel:?}, avx result is {avx_pixel:?}");
233+
}
234+
}
235+
}
236+
}

0 commit comments

Comments
 (0)