from functools import lru_cache

import numpy as np
import torch

try:
    import triton
    import triton.language as tl
except ImportError:
    raise RuntimeError("triton import failed; try `pip install --pre triton`")


@triton.jit
def dtw_kernel(
    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
):
    # Fills `cost` and `trace` one row per iteration for k = 1 .. N + M.
    #
    # Arguments, as used below:
    #   cost, trace  - base pointers; row r starts at `base + r * <stride>`.
    #                  Row k + 1 of both is written (shifted by one element)
    #                  from rows k - 1 and k of `cost`, so the first rows of
    #                  `cost` must be initialised by the caller.
    #   x            - base pointer of the per-step input; row k - 1 is read
    #                  in iteration k.
    #   *_stride     - row strides added to the matching pointer.
    #   N, M         - loop runs N + M times; only the first M lanes are active.
    #   BLOCK_SIZE   - compile-time lane count; lanes >= M are masked off.
    #
    # NOTE(review): the `k = i + j` comment suggests each row is one
    # anti-diagonal of a skewed DTW matrix - confirm against the caller that
    # prepares `x`, `cost` and `trace`.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M

    for k in range(1, N + M + 1):  # k = i + j
        # Barrier before reading rows written by the previous iteration
        # (presumably to make those stores visible to all lanes - confirm).
        tl.debug_barrier()

        # Three candidate predecessors for every lane:
        #   p0 -> row k - 1, same column
        #   p1 -> row k,     same column
        #   p2 -> row k,     next column (+1 element)
        p0 = cost + (k - 1) * cost_stride
        p1 = cost + k * cost_stride
        p2 = cost + k * cost_stride + 1

        # No `other=` is given, so masked-off lanes hold unspecified values;
        # they are never written back because every store reuses `mask`.
        c0 = tl.load(p0 + offsets, mask=mask)
        c1 = tl.load(p1 + offsets, mask=mask)
        c2 = tl.load(p2 + offsets, mask=mask)

        # New cost = local input value + cheapest of the three predecessors.
        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

        # Results land in row k + 1, shifted right by one element.
        cost_ptr = cost + (k + 1) * cost_stride + 1
        tl.store(cost_ptr + offsets, cost_row, mask=mask)

        # Record which predecessor was the minimum (0 -> c0, 1 -> c1, 2 -> c2).
        # The stores are issued in the order 2, 1, 0, so on ties a later store
        # overwrites an earlier one: 0 wins over 1, which wins over 2.
        trace_ptr = trace + (k + 1) * trace_stride + 1
        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))


@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
    """Build the Triton median kernel specialised for ``filter_width``.

    The nested ``kernel`` below is only a template: three placeholder names
    in its body are replaced, at the source-text level, with code generated
    for the requested width:

    * one masked ``tl.load`` per window position (``row0`` .. ``row{W-1}``),
    * ``filter_width // 2 + 1`` bubble-sort passes of compare-and-swap steps
      over adjacent rows, which is enough for ``row{filter_width // 2}`` to
      hold its final sorted value,
    * the name of that middle row as the value to store.

    The result is memoised per ``filter_width`` so the source is generated
    only once for each width.

    Args:
        filter_width: number of neighbouring elements the median is taken
            over.

    Returns:
        A ``triton.JITFunction`` expecting
        ``(y, x, x_stride, y_stride, BLOCK_SIZE=...)``.
    """

    @triton.jit
    def kernel(
        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
    ):  # x.shape[-1] == filter_width
        row_idx = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < y_stride

        x_ptr = x + row_idx * x_stride  # noqa: F841
        y_ptr = y + row_idx * y_stride

        LOAD_ALL_ROWS_HERE  # noqa: F821

        BUBBLESORT_HERE  # noqa: F821

        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821

    kernel = triton.JITFunction(kernel.fn)
    middle = filter_width // 2

    # rowI holds, for every output column, the I-th element of its window.
    load_code = "\n".join(
        f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
        for i in range(filter_width)
    )

    # Partial bubble sort: pass number p pushes the p-th largest value into
    # place, so only `middle + 1` passes are needed to settle the middle row.
    passes = []
    for done in range(middle + 1):
        swaps = []
        for j in range(filter_width - done - 1):
            swaps.append(
                "\n".join(
                    [
                        f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
                        f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
                        f"    row{j} = smaller",
                        f"    row{j + 1} = larger",
                    ]
                )
            )
        passes.append("\n\n".join(swaps))
    sort_code = "\n\n".join(passes)

    source = kernel.src
    source = source.replace("    LOAD_ALL_ROWS_HERE", load_code)
    source = source.replace("    BUBBLESORT_HERE", sort_code)
    source = source.replace("MIDDLE_ROW_HERE", f"row{middle}")

    # Newer Triton releases reject direct assignment to `src` and expose
    # `_unsafe_update_src()` instead (the cached hash must then be cleared by
    # hand); older releases only support plain assignment.
    if hasattr(kernel, "_unsafe_update_src"):
        kernel._unsafe_update_src(source)
        kernel.hash = None
    else:
        kernel.src = source

    return kernel


def median_filter_cuda(x: torch.Tensor, filter_width: int):
    """Apply a median filter of given width along the last dimension of x

    No padding is applied here: every output element is the median of
    ``filter_width`` consecutive input elements, so the last dimension of the
    result has ``x.shape[-1] - filter_width + 1`` entries.

    Args:
        x: input tensor with at least two dimensions (``stride(-2)`` is
            required). One kernel program is launched per row, i.e. per index
            into all dimensions except the last.
        filter_width: size of the sliding window.

    Returns:
        A new tensor holding the filtered rows.
    """
    # The kernel addresses rows as `x + row * x.stride(-2)` with unit stride
    # inside a row, so it must be handed the contiguous tensor itself, not the
    # caller's (possibly strided) original.
    x = x.contiguous()
    slices = x.unfold(-1, filter_width, 1)

    # Plain Python int for the launch grid (np.prod returns a NumPy scalar).
    grid = int(np.prod(slices.shape[:-2]))

    kernel = median_kernel(filter_width)
    y = torch.empty_like(slices[..., 0])

    # Smallest power of two covering one output row.
    # NOTE(review): this relies on `y` being densely packed so that
    # y.stride(-2) equals the output row length - confirm.
    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)

    return y
