Skip to content
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

* Updated `searchsorted` implementations to align with the 2025.12 array API spec [gh-2902](https://github.com/IntelPython/dpnp/pull/2902)

### Deprecated

### Removed
Expand Down
12 changes: 5 additions & 7 deletions dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,13 @@ def searchsorted(a, v, side="left", sorter=None):

"""

usm_a = dpnp.get_usm_ndarray(a)
if dpnp.isscalar(v):
usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
else:
usm_v = dpnp.get_usm_ndarray(v)
a = dpnp.get_usm_ndarray(a)
if not dpnp.isscalar(v):
v = dpnp.get_usm_ndarray(v)

usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
return dpnp_array._create_from_usm_ndarray(
dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
dpt.searchsorted(a, v, side=side, sorter=sorter)
)


Expand Down
100 changes: 59 additions & 41 deletions dpnp/tensor/_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,41 @@
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************

from typing import Literal

from typing import Literal, Union

import dpctl
import dpctl.utils as du

import dpnp.tensor as dpt

from ._compute_follows_data import (
ExecutionPlacementError,
get_coerced_usm_type,
get_execution_queue,
)
from ._copy_utils import _empty_like_orderK
from ._ctors import empty
from ._ctors import empty_like
from ._scalar_utils import _get_dtype, _get_queue_usm_type, _validate_dtype
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
from ._tensor_impl import _take as ti_take
from ._tensor_impl import (
default_device_index_type as ti_default_device_index_type,
)
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
from ._type_utils import isdtype, result_type
from ._type_utils import (
_resolve_weak_types_all_py_ints,
_to_device_supported_dtype,
isdtype,
)
from ._usmarray import usm_ndarray


def searchsorted(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind to populate the changelog also?

x1: usm_ndarray,
x2: usm_ndarray,
x2: usm_ndarray | int | float | complex | bool,
/,
*,
side: Literal["left", "right"] = "left",
sorter: Union[usm_ndarray, None] = None,
sorter: usm_ndarray | None = None,
) -> usm_ndarray:
"""searchsorted(x1, x2, side='left', sorter=None)

Expand All @@ -68,8 +73,8 @@ def searchsorted(
input array. Must be a one-dimensional array. If `sorter` is
`None`, must be sorted in ascending order; otherwise, `sorter` must
be an array of indices that sort `x1` in ascending order.
x2 (usm_ndarray):
array containing search values.
x2 (usm_ndarray | int | float | complex | bool):
search value or values.
side (Literal["left", "right]):
argument controlling which index is returned if a value lands
exactly on an edge. If `x2` is an array of rank `N` where
Expand All @@ -85,13 +90,11 @@ def searchsorted(
array of indices that sort `x1` in ascending order. The array must
have the same shape as `x1` and have an integral data type.
Out of bound index values of `sorter` array are treated using
`"wrap"` mode documented in :py:func:`dpctl.tensor.take`.
`"wrap"` mode documented in :py:func:`dpnp.tensor.take`.
Default: `None`.
"""
if not isinstance(x1, usm_ndarray):
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(x1)}")
if not isinstance(x2, usm_ndarray):
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(x2)}")
if sorter is not None and not isinstance(sorter, usm_ndarray):
raise TypeError(f"Expected dpnp.tensor.usm_ndarray, got {type(sorter)}")

Expand All @@ -101,27 +104,43 @@ def searchsorted(
"Expected either 'left' or 'right'"
)

if sorter is None:
q = get_execution_queue([x1.sycl_queue, x2.sycl_queue])
else:
q = get_execution_queue(
[x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue]
)
q1, x1_usm_type = x1.sycl_queue, x1.usm_type
q2, x2_usm_type = _get_queue_usm_type(x2)
q3 = sorter.sycl_queue if sorter is not None else None
q = get_execution_queue(tuple(q for q in (q1, q2, q3) if q is not None))
if q is None:
raise ExecutionPlacementError(
"Execution placement can not be unambiguously "
"inferred from input arguments."
)

res_usm_type = get_coerced_usm_type(
tuple(
ut
for ut in (
x1_usm_type,
x2_usm_type,
)
if ut is not None
)
)
dpt.validate_usm_type(res_usm_type, allow_none=False)
sycl_dev = q.sycl_device

if x1.ndim != 1:
raise ValueError("First argument array must be one-dimensional")

x1_dt = x1.dtype
x2_dt = x2.dtype
x2_dt = _get_dtype(x2, sycl_dev)
if not _validate_dtype(x2_dt):
raise ValueError(
Comment thread
ndgrigorian marked this conversation as resolved.
"dpt.searchsorted search value argument has "
f"unsupported data type {x2_dt}"
)

_manager = du.SequentialOrderManager[q]
dep_evs = _manager.submitted_events
ev = dpctl.SyclEvent()
x1_deps = dep_evs
if sorter is not None:
if not isdtype(sorter.dtype, "integral"):
raise ValueError(
Expand All @@ -132,7 +151,7 @@ def searchsorted(
"Sorter array must be one-dimension with the same "
"shape as the first argument array"
)
res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
res = empty_like(x1)
ind = (sorter,)
axis = 0
wrap_out_of_bound_indices_mode = 0
Expand All @@ -146,31 +165,30 @@ def searchsorted(
depends=dep_evs,
)
x1 = res
x1_deps = [ev]
_manager.add_event_pair(ht_ev, ev)

if x1_dt != x2_dt:
dt = result_type(x1, x2)
if x1_dt != dt:
x1_buf = _empty_like_orderK(x1, dt)
dep_evs = _manager.submitted_events
ht_ev, ev = ti_copy(
src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
)
_manager.add_event_pair(ht_ev, ev)
x1 = x1_buf
if x2_dt != dt:
x2_buf = _empty_like_orderK(x2, dt)
dep_evs = _manager.submitted_events
ht_ev, ev = ti_copy(
src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
)
_manager.add_event_pair(ht_ev, ev)
x2 = x2_buf
dt1, dt2 = _resolve_weak_types_all_py_ints(x1_dt, x2_dt, sycl_dev)
dt = _to_device_supported_dtype(dpt.result_type(dt1, dt2), sycl_dev)

if x1_dt != dt:
x1_buf = _empty_like_orderK(x1, dt)
# get the submitted events again to ensure the copy waits take call
ht_ev, ev = ti_copy(src=x1, dst=x1_buf, sycl_queue=q, depends=x1_deps)
_manager.add_event_pair(ht_ev, ev)
x1 = x1_buf

if not isinstance(x2, usm_ndarray):
x2 = dpt.asarray(x2, dtype=dt, usm_type=res_usm_type, sycl_queue=q)
elif x2_dt != dt:
x2_buf = _empty_like_orderK(x2, dt)
ht_ev, ev = ti_copy(src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs)
Comment thread
ndgrigorian marked this conversation as resolved.
_manager.add_event_pair(ht_ev, ev)
x2 = x2_buf

dst_usm_type = get_coerced_usm_type([x1.usm_type, x2.usm_type])
index_dt = ti_default_device_index_type(q)

dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
dst = _empty_like_orderK(x2, index_dt, usm_type=res_usm_type)

dep_evs = _manager.submitted_events
if side == "left":
Expand Down
104 changes: 61 additions & 43 deletions dpnp/tests/tensor/test_usm_ndarray_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
# THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************

import ctypes

import dpctl
import numpy as np
import pytest
Expand All @@ -37,6 +39,30 @@
skip_if_dtype_not_supported,
)

_integer_dtypes = [
"i1",
"u1",
"i2",
"u2",
"i4",
"u4",
"i8",
"u8",
]

_floating_dtypes = [
"f2",
"f4",
"f8",
]

_complex_dtypes = [
"c8",
"c16",
]

_all_dtypes = ["?"] + _integer_dtypes + _floating_dtypes + _complex_dtypes


def _check(hay_stack, needles, needles_np):
assert hay_stack.dtype == needles.dtype
Expand Down Expand Up @@ -103,19 +129,7 @@ def test_searchsorted_strided_bool():
)


@pytest.mark.parametrize(
"idt",
[
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
],
)
@pytest.mark.parametrize("idt", _integer_dtypes)
def test_searchsorted_contig_int(idt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
Expand All @@ -135,19 +149,7 @@ def test_searchsorted_contig_int(idt):
)


@pytest.mark.parametrize(
"idt",
[
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
],
)
@pytest.mark.parametrize("idt", _integer_dtypes)
def test_searchsorted_strided_int(idt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
Expand All @@ -174,12 +176,12 @@ def _add_extended_fp(array):
array[-1] = dpt.nan


@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
def test_searchsorted_contig_fp(idt):
@pytest.mark.parametrize("fdt", _floating_dtypes)
def test_searchsorted_contig_fp(fdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(fdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(fdt)

hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
_add_extended_fp(hay_stack)
Expand All @@ -195,12 +197,12 @@ def test_searchsorted_contig_fp(idt):
)


@pytest.mark.parametrize("idt", [dpt.float16, dpt.float32, dpt.float64])
def test_searchsorted_strided_fp(idt):
@pytest.mark.parametrize("fdt", _floating_dtypes)
def test_searchsorted_strided_fp(fdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(fdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(fdt)

hay_stack = dpt.repeat(
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
Expand Down Expand Up @@ -243,12 +245,12 @@ def _add_extended_cfp(array):
return dpt.sort(dpt.concat((ev, array)))


@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
def test_searchsorted_contig_cfp(idt):
@pytest.mark.parametrize("cdt", _complex_dtypes)
def test_searchsorted_contig_cfp(cdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(cdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(cdt)

hay_stack = dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True)
hay_stack = _add_extended_cfp(hay_stack)
Expand All @@ -263,12 +265,12 @@ def test_searchsorted_contig_cfp(idt):
)


@pytest.mark.parametrize("idt", [dpt.complex64, dpt.complex128])
def test_searchsorted_strided_cfp(idt):
@pytest.mark.parametrize("cdt", _complex_dtypes)
def test_searchsorted_strided_cfp(cdt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(idt, q)
skip_if_dtype_not_supported(cdt, q)

dt = dpt.dtype(idt)
dt = dpt.dtype(cdt)

hay_stack = dpt.repeat(
dpt.linspace(0, 1, num=255, dtype=dt, endpoint=True), 4
Expand Down Expand Up @@ -315,7 +317,7 @@ def test_searchsorted_validation():
x1 = dpt.arange(10, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("Default device could not be created")
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.searchsorted(x1, None)
with pytest.raises(TypeError):
dpt.searchsorted(x1, x1, sorter=dict())
Expand Down Expand Up @@ -405,3 +407,19 @@ def test_searchsorted_strided_scalar_needle():
needles = dpt.asarray(needles_np)

_check(hay_stack, needles, needles_np)


@pytest.mark.parametrize(
"py_zero",
[bool(0), int(0), float(0), complex(0), np.float32(0), ctypes.c_int(0)],
)
@pytest.mark.parametrize("dt", _all_dtypes)
def test_searchsorted_py_scalars(py_zero, dt):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dt, q)

x = dpt.zeros(10, dtype=dt, sycl_queue=q)

r1 = dpt.searchsorted(x, py_zero)
assert isinstance(r1, dpt.usm_ndarray)
assert r1.shape == ()
Loading