obvhs/cwbvh/
simd.rs

1use glam::*;
2#[cfg(target_arch = "x86")]
3use std::arch::x86::*;
4#[cfg(target_arch = "x86_64")]
5use std::arch::x86_64::*;
6
7use crate::{
8    cwbvh::{
9        CwBvhNode,
10        node::{EPSILON, extract_byte64},
11    },
12    ray::Ray,
13};
14
15impl CwBvhNode {
16    #[inline(always)]
17    pub fn intersect_ray_simd(&self, ray: &Ray, oct_inv4: u32) -> u32 {
18        let adj_ray_dir_inv = self.compute_extent() * ray.inv_direction;
19        let adj_ray_origin = (Vec3A::from(self.p) - ray.origin) * ray.inv_direction;
20        let mut hit_mask = 0u32;
21        unsafe {
22            let adj_ray_dir_inv_x = _mm_set1_ps(adj_ray_dir_inv.x);
23            let adj_ray_dir_inv_y = _mm_set1_ps(adj_ray_dir_inv.y);
24            let adj_ray_dir_inv_z = _mm_set1_ps(adj_ray_dir_inv.z);
25
26            let adj_ray_orig_x = _mm_set1_ps(adj_ray_origin.x);
27            let adj_ray_orig_y = _mm_set1_ps(adj_ray_origin.y);
28            let adj_ray_orig_z = _mm_set1_ps(adj_ray_origin.z);
29
30            let rdx = ray.direction.x < 0.0;
31            let rdy = ray.direction.y < 0.0;
32            let rdz = ray.direction.z < 0.0;
33
34            let (child_bits8, bit_index8) = self.get_child_and_index_bits(oct_inv4);
35
36            #[inline(always)]
37            fn get_q(v: &[u8; 8], i: usize) -> __m128 {
38                // get_q is the most expensive part of intersect_simd
39                // Tried version with _mm_cvtepu8_epi32 and _mm_cvtepi32_ps, it was a lot slower.
40                // Tried transmuting v into a u64 and bit shifting, it was a lot slower.
41                unsafe {
42                    _mm_set_ps(
43                        *v.get_unchecked(i * 4 + 3) as f32,
44                        *v.get_unchecked(i * 4 + 2) as f32,
45                        *v.get_unchecked(i * 4 + 1) as f32,
46                        *v.get_unchecked(i * 4) as f32,
47                    )
48                }
49            }
50
51            // Intersect 4 aabbs at a time:
52            for i in 0..2 {
53                // It's possible to select hi/lo outside the loop with child_min_x, etc... but that seems quite a bit slower
54                // using _mm_blendv_ps or similar instead of `if rdx`, etc... is slower
55
56                // Interleaving x, y, z like this is slightly faster than loading all at once. Tried using _mm_prefetch without luck
57                let q_lo_x = get_q(&self.child_min_x, i);
58                let q_hi_x = get_q(&self.child_max_x, i);
59                let x_min = if rdx { q_hi_x } else { q_lo_x };
60                let x_max = if rdx { q_lo_x } else { q_hi_x };
61                // Tried using _mm_fmadd_ps, it was a lot slower
62                let tmin_x = _mm_add_ps(_mm_mul_ps(x_min, adj_ray_dir_inv_x), adj_ray_orig_x);
63                let tmax_x = _mm_add_ps(_mm_mul_ps(x_max, adj_ray_dir_inv_x), adj_ray_orig_x);
64
65                let q_lo_y = get_q(&self.child_min_y, i);
66                let q_hi_y = get_q(&self.child_max_y, i);
67                let y_min = if rdy { q_hi_y } else { q_lo_y };
68                let y_max = if rdy { q_lo_y } else { q_hi_y };
69                let tmin_y = _mm_add_ps(_mm_mul_ps(y_min, adj_ray_dir_inv_y), adj_ray_orig_y);
70                let tmax_y = _mm_add_ps(_mm_mul_ps(y_max, adj_ray_dir_inv_y), adj_ray_orig_y);
71
72                let q_lo_z = get_q(&self.child_min_z, i);
73                let q_hi_z = get_q(&self.child_max_z, i);
74                let z_min = if rdz { q_hi_z } else { q_lo_z };
75                let z_max = if rdz { q_lo_z } else { q_hi_z };
76                let tmin_z = _mm_add_ps(_mm_mul_ps(z_min, adj_ray_dir_inv_z), adj_ray_orig_z);
77                let tmax_z = _mm_add_ps(_mm_mul_ps(z_max, adj_ray_dir_inv_z), adj_ray_orig_z);
78
79                // Tried using _mm_fmadd_ps, it was a lot slower
80                // Compute intersection
81                let tmin = _mm_max_ps(tmin_x, _mm_max_ps(tmin_y, tmin_z));
82                let tmax = _mm_min_ps(tmax_x, _mm_min_ps(tmax_y, tmax_z));
83                let tmin = _mm_max_ps(tmin, _mm_set1_ps(EPSILON)); //ray.tmin?
84                let tmax = _mm_min_ps(tmax, _mm_set1_ps(ray.tmax));
85
86                let intersected = _mm_cmple_ps(tmin, tmax);
87                let mask = _mm_movemask_ps(intersected);
88
89                for j in 0..4 {
90                    let offset = i * 4 + j;
91                    if (mask & (1 << j)) != 0 {
92                        let child_bits = extract_byte64(child_bits8, offset);
93                        let bit_index = extract_byte64(bit_index8, offset);
94                        hit_mask |= child_bits << bit_index;
95                    }
96                }
97            }
98        }
99        hit_mask
100    }
101}