Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 51 additions & 36 deletions src/unicode_string/str.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 Colin Finck <colin@reactos.org>
// Copyright 2023-2026 Colin Finck <colin@reactos.org>
// SPDX-License-Identifier: MIT OR Apache-2.0

use core::cmp::Ordering;
Expand Down Expand Up @@ -158,12 +158,7 @@ impl<'a> NtUnicodeStr<'a> {
///
/// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
pub fn try_from_u16(buffer: &'a [u16]) -> Result<Self> {
let elements = buffer.len();
let length_usize = elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let length =
u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?;
let length = Self::try_length_from_u16(buffer)?;

Ok(Self {
raw: RawNtString {
Expand Down Expand Up @@ -191,9 +186,48 @@ impl<'a> NtUnicodeStr<'a> {
///
/// [`try_from_u16`]: Self::try_from_u16
pub fn try_from_u16_until_nul(buffer: &'a [u16]) -> Result<Self> {
let length;
let maximum_length;
let (length, maximum_length) = Self::try_length_from_u16_until_nul(buffer)?;

Ok(Self {
raw: RawNtString {
length,
maximum_length,
buffer: buffer.as_ptr(),
},
_lifetime: PhantomData,
})
}

pub(crate) fn try_length_from_u16(buffer: &[u16]) -> Result<u16> {
let elements = buffer.len();
let length_usize = elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let length =
u16::try_from(length_usize).map_err(|_| NtStringError::BufferSizeExceedsU16)?;

Ok(length)
}

pub(crate) fn try_length_from_u16_cstr(u16cstr: &U16CStr) -> Result<(u16, u16)> {
let buffer = u16cstr.as_slice_with_nul();

// Include the terminating NUL character in `maximum_length` ...
let maximum_length_in_elements = buffer.len();
let maximum_length_in_bytes = maximum_length_in_elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let maximum_length = u16::try_from(maximum_length_in_bytes)
.map_err(|_| NtStringError::BufferSizeExceedsU16)?;

// ... but not in `length`
debug_assert!(maximum_length >= mem::size_of::<u16>() as u16);
let length = maximum_length - mem::size_of::<u16>() as u16;

Ok((length, maximum_length))
}

pub(crate) fn try_length_from_u16_until_nul(buffer: &[u16]) -> Result<(u16, u16)> {
match buffer.iter().position(|x| *x == 0) {
Some(nul_pos) => {
// Include the terminating NUL character in `maximum_length` ...
Expand All @@ -203,23 +237,16 @@ impl<'a> NtUnicodeStr<'a> {
let maximum_length_usize = maximum_elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
maximum_length = u16::try_from(maximum_length_usize)
let maximum_length = u16::try_from(maximum_length_usize)
.map_err(|_| NtStringError::BufferSizeExceedsU16)?;

// ... but not in `length`
length = maximum_length - mem::size_of::<u16>() as u16;
}
None => return Err(NtStringError::NulNotFound),
};
let length = maximum_length - mem::size_of::<u16>() as u16;

Ok(Self {
raw: RawNtString {
length,
maximum_length,
buffer: buffer.as_ptr(),
},
_lifetime: PhantomData,
})
Ok((length, maximum_length))
}
None => Err(NtStringError::NulNotFound),
}
}

pub(crate) fn u16_iter(&'a self) -> Copied<Iter<'a, u16>> {
Expand Down Expand Up @@ -314,25 +341,13 @@ impl<'a> TryFrom<&'a U16CStr> for NtUnicodeStr<'a> {
/// The internal buffer will be NUL-terminated.
/// See the [module-level documentation](super) for the implications of that.
fn try_from(value: &'a U16CStr) -> Result<Self> {
let buffer = value.as_slice_with_nul();

// Include the terminating NUL character in `maximum_length` ...
let maximum_length_in_elements = buffer.len();
let maximum_length_in_bytes = maximum_length_in_elements
.checked_mul(mem::size_of::<u16>())
.ok_or(NtStringError::BufferSizeExceedsU16)?;
let maximum_length = u16::try_from(maximum_length_in_bytes)
.map_err(|_| NtStringError::BufferSizeExceedsU16)?;

// ... but not in `length`
debug_assert!(maximum_length >= mem::size_of::<u16>() as u16);
let length = maximum_length - mem::size_of::<u16>() as u16;
let (length, maximum_length) = Self::try_length_from_u16_cstr(value)?;

Ok(Self {
raw: RawNtString {
length,
maximum_length,
buffer: buffer.as_ptr(),
buffer: value.as_ptr(),
},
_lifetime: PhantomData,
})
Expand Down
50 changes: 28 additions & 22 deletions src/unicode_string/strmut.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 Colin Finck <colin@reactos.org>
// Copyright 2023-2026 Colin Finck <colin@reactos.org>
// SPDX-License-Identifier: MIT OR Apache-2.0

use core::cmp::Ordering;
Expand Down Expand Up @@ -122,14 +122,16 @@ impl<'a> NtUnicodeStrMut<'a> {
///
/// [`try_from_u16_until_nul`]: Self::try_from_u16_until_nul
pub fn try_from_u16(buffer: &mut [u16]) -> Result<Self> {
let unicode_str = NtUnicodeStr::try_from_u16(buffer)?;
let length = NtUnicodeStr::try_length_from_u16(buffer)?;

// SAFETY: `unicode_str` was created from a mutable `buffer` and
// `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout,
// so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`.
let unicode_str_mut = unsafe { mem::transmute(unicode_str) };

Ok(unicode_str_mut)
Ok(Self {
raw: RawNtString {
length,
maximum_length: length,
buffer: buffer.as_mut_ptr(),
},
_lifetime: PhantomData,
})
}

/// Creates an [`NtUnicodeStrMut`] from an existing [`u16`] string buffer that contains at least one NUL character.
Expand All @@ -148,14 +150,16 @@ impl<'a> NtUnicodeStrMut<'a> {
///
/// [`try_from_u16`]: Self::try_from_u16
pub fn try_from_u16_until_nul(buffer: &mut [u16]) -> Result<Self> {
let unicode_str = NtUnicodeStr::try_from_u16_until_nul(buffer)?;
let (length, maximum_length) = NtUnicodeStr::try_length_from_u16_until_nul(buffer)?;

// SAFETY: `unicode_str` was created from a mutable `buffer` and
// `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout,
// so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`.
let unicode_str_mut = unsafe { mem::transmute(unicode_str) };

Ok(unicode_str_mut)
Ok(Self {
raw: RawNtString {
length,
maximum_length,
buffer: buffer.as_mut_ptr(),
},
_lifetime: PhantomData,
})
}
}

Expand Down Expand Up @@ -192,14 +196,16 @@ impl<'a> TryFrom<&'a mut U16CStr> for NtUnicodeStrMut<'a> {
/// The internal buffer will be NUL-terminated.
/// See the [module-level documentation](super) for the implications of that.
fn try_from(value: &'a mut U16CStr) -> Result<Self> {
let unicode_str = NtUnicodeStr::try_from(&*value)?;
let (length, maximum_length) = NtUnicodeStr::try_length_from_u16_cstr(value)?;

// SAFETY: `unicode_str` was created from a mutable `value` and
// `NtUnicodeStr` and `NtUnicodeStrMut` have the same memory layout,
// so we can safely transmute `NtUnicodeStr` to `NtUnicodeStrMut`.
let unicode_str_mut = unsafe { mem::transmute(unicode_str) };

Ok(unicode_str_mut)
Ok(Self {
raw: RawNtString {
length,
maximum_length,
buffer: value.as_mut_ptr(),
},
_lifetime: PhantomData,
})
}
}

Expand Down
Loading