1use {
2    core::fmt::{self, Debug},
3    gpu_alloc_types::{MemoryPropertyFlags, MemoryType},
4};
5
6bitflags::bitflags! {
7    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10    #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
11    pub struct UsageFlags: u8 {
12        const FAST_DEVICE_ACCESS = 0x01;
15
16        const HOST_ACCESS = 0x02;
20
21        const DOWNLOAD = 0x04;
25
26        const UPLOAD = 0x08;
32
33        const TRANSIENT = 0x10;
39
40        const DEVICE_ADDRESS = 0x20;
43    }
44}
45
/// Memory types compatible with one particular `UsageFlags` combination.
#[derive(Clone, Copy, Debug)]
struct MemoryForOneUsage {
    // Bit `i` is set when memory type index `i` is compatible with the usage.
    mask: u32,
    // Compatible memory type indices, most preferred first; only the first
    // `types_count` entries are meaningful, the rest are padding.
    types: [u32; 32],
    // Number of valid entries at the start of `types`.
    types_count: u32,
}
52
/// Precomputed lookup table from usage flags to compatible memory types.
pub(crate) struct MemoryForUsage {
    // Indexed by `UsageFlags::bits()`; 64 entries cover every combination of the six flag bits.
    usages: [MemoryForOneUsage; 64],
}
56
57impl Debug for MemoryForUsage {
58    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
59        fmt.debug_struct("MemoryForUsage")
60            .field("usages", &&self.usages[..])
61            .finish()
62    }
63}
64
65impl MemoryForUsage {
66    pub fn new(memory_types: &[MemoryType]) -> Self {
67        assert!(
68            memory_types.len() <= 32,
69            "Only up to 32 memory types supported"
70        );
71
72        let mut mfu = MemoryForUsage {
73            usages: [MemoryForOneUsage {
74                mask: 0,
75                types: [0; 32],
76                types_count: 0,
77            }; 64],
78        };
79
80        for usage in 0..64 {
81            mfu.usages[usage as usize] =
82                one_usage(UsageFlags::from_bits_truncate(usage), memory_types);
83        }
84
85        mfu
86    }
87
88    pub fn mask(&self, usage: UsageFlags) -> u32 {
91        self.usages[usage.bits() as usize].mask
92    }
93
94    pub fn types(&self, usage: UsageFlags) -> &[u32] {
97        let usage = &self.usages[usage.bits() as usize];
98        &usage.types[..usage.types_count as usize]
99    }
100}
101
102fn one_usage(usage: UsageFlags, memory_types: &[MemoryType]) -> MemoryForOneUsage {
103    let mut types = [0; 32];
104    let mut types_count = 0;
105
106    for (index, mt) in memory_types.iter().enumerate() {
107        if compatible(usage, mt.props) {
108            types[types_count as usize] = index as u32;
109            types_count += 1;
110        }
111    }
112
113    types[..types_count as usize]
114        .sort_unstable_by_key(|&index| reverse_priority(usage, memory_types[index as usize].props));
115
116    let mask = types[..types_count as usize]
117        .iter()
118        .fold(0u32, |mask, index| mask | 1u32 << index);
119
120    MemoryForOneUsage {
121        mask,
122        types,
123        types_count,
124    }
125}
126
127fn compatible(usage: UsageFlags, flags: MemoryPropertyFlags) -> bool {
128    type Flags = MemoryPropertyFlags;
129    if flags.contains(Flags::LAZILY_ALLOCATED) || flags.contains(Flags::PROTECTED) {
130        false
132    } else if usage.intersects(UsageFlags::HOST_ACCESS | UsageFlags::UPLOAD | UsageFlags::DOWNLOAD)
133    {
134        flags.contains(Flags::HOST_VISIBLE)
136    } else {
137        true
138    }
139}
140
141fn reverse_priority(usage: UsageFlags, flags: MemoryPropertyFlags) -> u32 {
144    type Flags = MemoryPropertyFlags;
145
146    let device_local: bool = flags.contains(Flags::DEVICE_LOCAL)
149        ^ (usage.is_empty() || usage.contains(UsageFlags::FAST_DEVICE_ACCESS));
150
151    assert!(
152        flags.contains(Flags::HOST_VISIBLE)
153            || !usage
154                .intersects(UsageFlags::HOST_ACCESS | UsageFlags::UPLOAD | UsageFlags::DOWNLOAD)
155    );
156
157    let host_visible: bool = flags.contains(Flags::HOST_VISIBLE)
159        ^ usage.intersects(UsageFlags::HOST_ACCESS | UsageFlags::UPLOAD | UsageFlags::DOWNLOAD);
160
161    let host_cached: bool =
164        flags.contains(Flags::HOST_CACHED) ^ usage.contains(UsageFlags::DOWNLOAD);
165
166    let host_coherent: bool = flags.contains(Flags::HOST_COHERENT)
169        ^ (usage.intersects(UsageFlags::UPLOAD | UsageFlags::DOWNLOAD));
170
171    device_local as u32 * 8
173        + host_visible as u32 * 4
174        + host_cached as u32 * 2
175        + host_coherent as u32
176}