# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for lowering JAX to Mosaic-compatible MLIR dialects."""
from __future__ import annotations

from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import functools
import string
from typing import Any, Hashable

import jax
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import checkify
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import prng
from jax._src import source_info_util
from jax._src import state
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import math
from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import error_handling
from jax._src.pallas.mosaic import primitives as tpu_primitives
from jax._src.pallas.mosaic import random as pl_random
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
from jax._src.state.types import RefBitcaster, RefReshaper
from jax._src.state.utils import dtype_bitwidth
from jax._src.typing import Array, DTypeLike
from jax._src.util import safe_map
from jax._src.util import safe_zip
from jax._src.util import split_list
from jax._src.util import unzip2
from jax.experimental.mosaic.dialects import tpu
import jax.numpy as jnp
from jaxlib.mlir.ir import Module
import numpy as np

# TODO(sharadmv): enable type checking
# mypy: ignore-errors

# Short aliases for types and memory spaces used throughout this module.
NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
# A memory space may be either a generic Pallas one or a TPU-specific one.
MemorySpace = pallas_core.MemorySpace | TPUMemorySpace
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
# Booleans are stored as the following type in memrefs.
BOOL_MEMREF_TYPE = np.dtype('int32')

# The value interpreted as a dynamic dimension by MLIR (equal to -2**63).
MLIR_DYNAMIC = -9223372036854775808

partial = functools.partial
# NOTE: from here on `map` and `zip` refer to jax's `safe_map`/`safe_zip`;
# the Python builtins remain reachable as `unsafe_map`/`unsafe_zip`.
map, unsafe_map = safe_map, map  # pylint: disable=redefined-builtin
zip, unsafe_zip = safe_zip, zip  # pylint: disable=redefined-builtin


@dataclasses.dataclass
class MeshContext:
  """Device-mesh description carried inside a `LoweringContext`.

  Within this module it is built by copying the three fields of a `MeshInfo`
  (see `lower_jaxpr_to_func` and `lower_jaxpr_to_transform_func`).
  """
  # Size of each mesh axis.
  mesh_shape: tuple[int, ...]
  # Names of the mesh axes.
  axis_names: tuple[str, ...]
  # Strides for the mesh axes (computed from the mesh shape by the caller).
  mesh_strides: tuple[int, ...]


@dataclasses.dataclass
class LoweringContext:
  """State shared by every lowering rule while one jaxpr is being lowered."""
  ir_context: ir.Context
  grid_sizes: tuple[int, ...]  # Includes both user and vmap axes.
  grid_names: tuple[Hashable, ...] | None
  mapped_dims: tuple[int, ...]  # Indices of vmapped grid dimensions.
  user_grid_indices: Sequence[ir.Value] | None
  block_shapes: list[tuple[int | pallas_core.Mapped, ...]]
  name_stack: source_info_util.NameStack
  mesh_context: MeshContext | None
  traceback_caches: mlir.TracebackCaches
  for_verification: bool

  # Not a field (unannotated): lets callers write `ctx.replace(field=value)`.
  replace = dataclasses.replace

  @property
  def grid_rank(self):
    """Number of grid dimensions, counting vmapped ones."""
    return len(self.grid_sizes)

  @contextlib.contextmanager
  def grid_name_context(self):
    """Context manager that extends the axis env with the named grid axes.

    Does nothing when the grid has no names. Otherwise each grid name is
    paired with the size of the corresponding non-vmapped grid dimension.
    """
    # TODO(b/355036977): generalize this across other platforms
    names = self.grid_names
    if names:
      user_sizes = tuple(
          size
          for dim, size in enumerate(self.grid_sizes)
          if dim not in self.mapped_dims
      )
      with jax_core.extend_axis_env_nd(zip(names, user_sizes)):
        yield
    else:
      yield


@dataclasses.dataclass
class LoweringRuleContext:
  """Per-equation context handed to an individual lowering rule."""
  # The enclosing jaxpr-level lowering context.
  lowering_context: LoweringContext
  # Abstract values of the equation's inputs and outputs.
  avals_in: Sequence[jax_core.AbstractValue]
  avals_out: Sequence[jax_core.AbstractValue]
  # Block shape recorded for each input; None when no block shape is known.
  block_shapes: Sequence[tuple[int | pallas_core.Mapped, ...] | None]

  # Not a field (unannotated): lets callers write `ctx.replace(field=value)`.
  replace = dataclasses.replace


def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None
                                     ) -> TPUMemorySpace:
  match memory_space:
    case None:
      # We pick VMEM as the default one when no memory space is
      # specified
      return TPUMemorySpace.VMEM
    case pallas_core.MemorySpace.ANY:
      # Map the general ANY memory space to TPU ANY memory space
      return TPUMemorySpace.ANY
    case pallas_core.MemorySpace.ERROR | pallas_core.MemorySpace.INDEX:
      return TPUMemorySpace.SMEM
    case TPUMemorySpace():
      # Leave the memory space unchanged
      return memory_space
    case _:
      raise ValueError("Invalid memory space: {memory_space}")


def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None
                                      ) -> ir.Attribute:
  """Returns the `#tpu.memory_space<...>` attribute for `memory_space`."""
  attr_text = (
      f"#tpu.memory_space<{_memory_space_to_tpu_memory_space(memory_space)}>"
  )
  return ir.Attribute.parse(attr_text)

def _dtype_to_ir_type(dtype: jnp.dtype,
                      is_kernel_boundary: bool = False) -> ir.Type:
  """Converts a dtype to the MLIR element type used by this lowering.

  Args:
    dtype: The dtype to convert; semaphore dtypes are handled specially.
    is_kernel_boundary: When true, booleans are converted as
      `BOOL_MEMREF_TYPE` instead of as a boolean type.

  Returns:
    The MLIR type. Integer types are always returned as signless.

  Raises:
    NotImplementedError: For a semaphore dtype that is not a DMA, regular or
      barrier semaphore.
  """
  if jnp.issubdtype(dtype, tpu_core.semaphore_dtype):
    if jnp.issubdtype(dtype, tpu_core.dma_semaphore):
      return ir.Type.parse("!tpu.dma_semaphore")
    # Regular and barrier semaphores lower to the same type.
    if (jnp.issubdtype(dtype, tpu_core.semaphore)
        or jnp.issubdtype(dtype, tpu_core.barrier_semaphore)):
      return ir.Type.parse("!tpu.semaphore")
    raise NotImplementedError
  if is_kernel_boundary and jnp.issubdtype(dtype, jnp.dtype('bool')):
    dtype = BOOL_MEMREF_TYPE
  # TODO(justinfu): Remove after mosaic supports unsigned types.
  # This conversion makes mosaic interpret all unsigned types as signed types.
  ir_type = mlir.dtype_to_ir_type(dtype)
  if not isinstance(ir_type, ir.IntegerType):
    return ir_type
  return ir.IntegerType.get_signless(ir_type.width)

def aval_to_ir_type(aval,
                    shape=None,
                    memory_space: MemorySpace | None = None,
                    is_kernel_boundary: bool = False):
  """Builds the MLIR type that represents `aval`.

  Args:
    aval: The abstract value to convert.
    shape: Optional shape override; defaults to `aval.shape` for refs and
      arrays. Ignored for semaphores and replaced by the key shape for PRNG
      keys.
    memory_space: Memory space for memref results.
    is_kernel_boundary: Forwarded to `_dtype_to_ir_type` for array avals.

  Returns:
    A memref type for semaphores, PRNG keys and refs; for shaped arrays, a
    vector type, or the bare element type when the shape is empty.

  Raises:
    ValueError: For an unknown semaphore type, or a Pallas PRNG key that is
      not placed in SMEM.
    NotImplementedError: If `aval` is of an unsupported kind.
  """
  if isinstance(aval, tpu_core.AbstractSemaphore):
    sem_kind = aval.sem_type
    if sem_kind is tpu_core.SemaphoreType.DMA:
      sem_type = ir.Type.parse("!tpu.dma_semaphore")
    elif (sem_kind is tpu_core.SemaphoreType.REGULAR
          or sem_kind is tpu_core.SemaphoreType.BARRIER):
      # Regular and barrier semaphores share one type.
      sem_type = ir.Type.parse("!tpu.semaphore")
    else:
      raise ValueError(f"Cannot allocate {aval.sem_type}.")
    sem_space = _memory_space_to_mosaic_attribute(TPUMemorySpace.SEMAPHORE)
    return ir.MemRefType.get((), sem_type, memory_space=sem_space)

  if dtypes.issubdtype(aval.dtype, dtypes.prng_key):
    key_impl = aval.dtype._impl
    shape = key_impl.key_shape
    if pl_random.is_pallas_impl(key_impl):
      if memory_space is None:
        memory_space = TPUMemorySpace.SMEM
      if memory_space != TPUMemorySpace.SMEM:
        raise ValueError(
            f"PRNG keys must be stored in SMEM. Got {memory_space}"
        )
    key_space = _memory_space_to_mosaic_attribute(memory_space)
    # Keys are materialized as uint32 words.
    return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)),
                             memory_space=key_space)

  if isinstance(aval, state.AbstractRef):
    ref_shape = aval.shape if shape is None else shape
    ref_space = _memory_space_to_mosaic_attribute(memory_space)
    # Refs always use the kernel-boundary element type.
    elem_type = _dtype_to_ir_type(aval.dtype, is_kernel_boundary=True)
    return ir.MemRefType.get(ref_shape, elem_type, memory_space=ref_space)

  if isinstance(aval, jax_core.ShapedArray):
    array_shape = aval.shape if shape is None else shape
    elem_type = _dtype_to_ir_type(
        aval.dtype, is_kernel_boundary=is_kernel_boundary)
    # An empty shape yields the bare element type rather than a vector.
    if not array_shape:
      return elem_type
    return ir.VectorType.get(array_shape, elem_type)

  raise NotImplementedError(aval)


def ir_constant(x, mlir_type=None):
  """Emits an `arith.constant` holding the scalar `x`.

  Args:
    x: A Python int/float or an object with a `dtype`. Plain ints and floats
      are first converted to int32 / float32 NumPy arrays.
    mlir_type: Type of the constant; derived from `x.dtype` when not given.

  Returns:
    The result of `arith.constant`.

  Raises:
    NotImplementedError: If the dtype of `x` is not one of the handled ones.
  """
  if not hasattr(x, "dtype"):
    if isinstance(x, int):
      x = np.array(x, np.int32)
    elif isinstance(x, float):
      x = np.array(x, np.float32)
  if not mlir_type:
    mlir_type = _dtype_to_ir_type(x.dtype)

  # Choose the attribute kind by value type, then emit a single constant op.
  if isinstance(x, int) or x.dtype in (np.int32, np.uint32, np.int8):
    attr = ir.IntegerAttr.get(mlir_type, int(x))
  elif (isinstance(x, float) or x.dtype == np.float32
        or x.dtype == jnp.bfloat16):
    attr = ir.FloatAttr.get(mlir_type, float(x))
  elif x.dtype == jnp.bool_:
    attr = ir.BoolAttr.get(bool(x))
  else:
    raise NotImplementedError(x.dtype)
  return arith.constant(mlir_type, attr)


# Maps a JAX primitive to the function that lowers it; populated by the
# registrations that follow each rule in this module.
lowering_rules = {}
# Primitives whose rules receive their inputs as-is: `jaxpr_subcomp` does not
# run `_ensure_mlir_value` over the inputs of these primitives.
skip_mlir_conversions = set()

def _get_aval_physical_dtype_shape(aval):
  """Returns the trailing dims the physical aval adds beyond `aval.shape`."""
  physical_shape = jax_core.physical_aval(aval).shape
  return physical_shape[len(aval.shape):]

def _get_arg_type(
    aval,
    block_mapping: pallas_core.BlockMapping | None,
):
  """Returns `(ir_type, block_shape)` for one kernel argument.

  Args:
    aval: Abstract value of the argument.
    block_mapping: The argument's block mapping, or None if it has none.

  Returns:
    A pair of the argument's MLIR type and its block shape. The block shape
    is None for semaphores, `aval.shape` when there is no block mapping, and
    `block_mapping.block_shape` otherwise.
  """
  memory_space = None
  if isinstance(aval, pallas_core.AbstractMemoryRef):
    # We assume unannotated memory refs are in VMEM
    memory_space = (
        TPUMemorySpace.VMEM if aval.memory_space is None
        else aval.memory_space
    )
  if isinstance(aval, tpu_core.AbstractSemaphore):
    return aval_to_ir_type(aval), None
  # TODO(necula): clean this None block_mapping
  if block_mapping is None:
    return aval_to_ir_type(aval, memory_space=memory_space), aval.shape
  # Mapped dimensions occupy a single element in the type's shape.
  type_shape = tuple(
      1 if dim is pallas_core.mapped else dim
      for dim in block_mapping.block_shape
  )
  ir_type = aval_to_ir_type(aval, shape=type_shape, memory_space=memory_space)
  return ir_type, block_mapping.block_shape


@dataclasses.dataclass(init=False)
class MosaicGridMapping:
  """Mosaic-side view of a Pallas `GridMapping` together with its kernel jaxpr.

  The constructor derives, from the jaxpr's input avals and the grid mapping,
  the MLIR argument types and block shapes of the kernel function, the
  per-dimension semantics (including vmapped dimensions) and, when the kernel
  communicates across named axes, the mesh information.
  """
  # Grid sizes, grid axis names, kernel jaxpr and block mappings, copied from
  # the constructor arguments.
  grid: tuple[int, ...] | None
  grid_names: tuple[Hashable, ...] | None
  jaxpr: jax_core.Jaxpr
  block_mappings: tuple[pallas_core.BlockMapping | None, ...]
  # Indices of grid dimensions introduced by vmap.
  mapped_dims: tuple[int, ...]
  # MLIR types of each argument group of the kernel function.
  scalar_prefetch_types: tuple[ir.Type, ...]
  operand_types: tuple[ir.Type, ...]
  scratch_types: tuple[ir.Type, ...]
  grid_types: tuple[ir.Type, ...]
  # Block shapes of each argument group (None for semaphore scratch args).
  scalar_prefetch_block_shapes: tuple[tuple[int, ...], ...]
  operand_block_shapes: tuple[tuple[int, ...], ...]
  scratch_block_shapes: tuple[tuple[int, ...], ...]
  # Mesh description; None when the kernel has no cross-axis communication.
  mesh_info: MeshInfo | None
  # Callable `(indices, maybe_include_mapped_dims)` selecting the grid indices.
  get_grid_indices: Callable | None

  def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping,
               dimension_semantics: tuple[str, ...] | None,
               mesh: mesh_lib.Mesh | None):
    """Populates every attribute from `jaxpr` and `grid_mapping`.

    Args:
      jaxpr: The kernel jaxpr.
      grid_mapping: The Pallas grid mapping for the call.
      dimension_semantics: One entry per non-vmapped grid dimension, or None
        to use "arbitrary" for all of them.
      mesh: The device mesh, needed only if the kernel communicates.

    Raises:
      ValueError: If `dimension_semantics` does not have one entry per
        non-vmapped grid dimension, or (via `_prepare_mesh_info`) if the mesh
        is missing or clashes with the grid names.
    """
    self.grid = grid_mapping.grid
    self.grid_names = grid_mapping.grid_names
    self.jaxpr = jaxpr
    self.block_mappings = grid_mapping.block_mappings
    self.mapped_dims = grid_mapping.vmapped_dims
    # TODO(mvoz): Generalize to not need this
    user_grid = tuple(
        g for i, g in enumerate(self.grid) if i not in self.mapped_dims
    )
    if dimension_semantics is None:
      dimension_semantics = ("arbitrary",) * len(user_grid)
    if len(user_grid) != len(dimension_semantics):
      raise ValueError(
          "Must have dimension semantics for each dimension of the grid."
      )
    assert len(self.mapped_dims) + len(dimension_semantics) == len(
        self.grid
    ), (
        f"Misconfigured grid: {self.mapped_dims=}, {dimension_semantics=},"
        f" {self.grid=}"
    )
    # dimension_semantics is user provided and won't take into account vmap
    # dimensions. Here we add in parallel dimensions for the vmaps.
    semantics_iter = iter(dimension_semantics)
    self._dimension_semantics = tuple(
        next(semantics_iter) if i not in self.mapped_dims else "parallel"
        for i in range(len(self.grid))
    )

    in_avals = [invar.aval for invar in self.jaxpr.invars]
    # jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
    scalar_prefetch_avals = in_avals[grid_mapping.slice_index_ops]
    operand_avals = in_avals[grid_mapping.slice_block_ops]
    scratch_avals = in_avals[grid_mapping.slice_scratch_ops]
    # Scalar prefetch and scratch arguments have no block mapping, so their
    # block shapes are taken directly from the avals below.
    self.scalar_prefetch_types, _ = unzip2([
        _get_arg_type(aval, None)
        for aval in scalar_prefetch_avals])
    self.scalar_prefetch_block_shapes = tuple(
        aval.shape for aval in scalar_prefetch_avals)
    self.operand_types, self.operand_block_shapes = unzip2([
        _get_arg_type(aval, block_mapping)
        for aval, block_mapping in zip(operand_avals, self.block_mappings)])
    self.scratch_types, _ = unzip2([
        _get_arg_type(aval, None) for aval in scratch_avals])
    self.scratch_block_shapes = tuple(
        aval.shape if not isinstance(aval, tpu_core.AbstractSemaphore) else None
        for aval in scratch_avals
    )
    # One index argument per grid dimension, all of the same type.
    self.grid_types, _ = unzip2([
        _get_arg_type(pallas_core.index_map_grid_aval, None)
        for _ in range(len(self.grid))
    ])
    # Relies on self.jaxpr, self.block_mappings and self.grid_names set above.
    self._prepare_mesh_info(mesh)

    if grid_mapping.get_grid_indices is None:

      # Avoid using self.mapped_dims within the function, since doing so will
      # introduce a self->_get_grid_indices->self reference cycle that means
      # MosaicGridMapping instances can only ever be deleted by GC, rather than
      # by their reference counts going to 0.
      mapped_dims = self.mapped_dims
      def _get_grid_indices(indices, maybe_include_mapped_dims: bool):
        # Default selector: either all indices, or only the non-vmapped ones.
        if maybe_include_mapped_dims:
          return indices
        return tuple(
            idx for i, idx in enumerate(indices) if i not in mapped_dims
        )

      self.get_grid_indices = _get_grid_indices
    else:
      self.get_grid_indices = grid_mapping.get_grid_indices

  def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None):
    """Sets `self.mesh_info` from `mesh`, or to None without communication.

    Raises:
      ValueError: If the kernel communicates but `mesh` is None, or if a mesh
        axis name is also used as a grid name.
    """
    if not self.has_communication:
      self.mesh_info = None
      return
    if mesh is None:
      raise ValueError(
          "Cannot use communication in pallas_call without shard_map."
      )
    axis_names = mesh.axis_names
    if self.grid_names is not None:
      if any(a in self.grid_names for a in axis_names):
        raise ValueError(
            "Cannot shadow axis mesh axis names with grid names. mesh axis"
            f" names: {mesh.axis_names}, grid names: {self.grid_names}"
        )
    # We need mesh <-> logical translation tables. Since the logical IDs are
    # just linearized versions of the mesh IDs, we create those tables.
    mesh_strides = pallas_utils.strides_from_shape(tuple(
        mesh.shape[a] for a in axis_names
    ))
    mesh_shape = tuple(mesh.shape.values())
    # MeshInfo is defined later in this module; it is resolved at call time.
    self.mesh_info = MeshInfo(mesh_shape, axis_names, mesh_strides)

  def maybe_compress_grid(self):
    """Placeholder for merging leading parallel grid dimensions; a no-op."""
    # If we have many leading parallel dimensions, we should "compress" them
    # into one so we can load balance across cores as best as we can.
    # TODO(sharadmv): implement this optimization
    pass

  @functools.cached_property
  def has_communication(self) -> bool:
    """Whether the kernel or an index map uses a non-grid named axis.

    True if the effects of the kernel jaxpr, or of any block mapping's index
    map jaxpr, include a `NamedAxisEffect` whose name is not a grid name.
    """
    nonlocal_axis_names = set()
    def _get_nonlocal_axis_names(jaxpr: jax_core.Jaxpr):
      # Names of named-axis effects that do not refer to a grid axis.
      return {
          e.name
          for e in jaxpr.effects
          if isinstance(e, jax_core.NamedAxisEffect)
          and (not self.grid_names or e.name not in self.grid_names)
      }
    nonlocal_axis_names.update(_get_nonlocal_axis_names(self.jaxpr))
    for bm in self.block_mappings:
      if bm is not None:
        nonlocal_axis_names.update(_get_nonlocal_axis_names(bm.index_map_jaxpr))
    return bool(nonlocal_axis_names)

  def get_extra_args(self) -> tuple[Any, ...]:
    """Returns extra arguments for the call; always empty here."""
    return ()

  def get_dimension_semantics(self) -> ir.ArrayAttr:
    """Returns the per-dimension semantics as an array of TPU attributes."""

    def _get_semantics(s: str | None) -> str:
      # A missing entry is treated as "arbitrary".
      if s is None:
        return "#tpu.dimension_semantics<arbitrary>"
      return f"#tpu.dimension_semantics<{s}>"

    return ir.ArrayAttr.get(
        map(
            ir.Attribute.parse,
            map(_get_semantics, self._dimension_semantics),
        )
    )

@dataclasses.dataclass
class MeshInfo:
  """Mesh description stored on a `MosaicGridMapping`."""
  # Size of each mesh axis.
  mesh_shape: tuple[int, ...]
  # Names of the mesh axes.
  # NOTE(review): annotated as a list, but `_prepare_mesh_info` passes
  # `mesh.axis_names` through unchanged — confirm the intended type.
  axis_names: list[str]
  # Strides computed from the mesh shape via `strides_from_shape`.
  mesh_strides: tuple[int, ...]


def _check_block_mappings(
    block_mappings: tuple[pallas_core.BlockMapping, ...],
    lowering_context: mlir.LoweringRuleContext,
    name_and_src_info: pallas_core.NameAndSrcInfo,
) -> None:
  """Validates that every block mapping is supported by the TPU lowering.

  Blocks in SMEM with a trivial window and blocks in SEMAPHORE memory are
  accepted without further checks. All other blocks must have rank >= 1, must
  be un-windowed when in memory space ANY, and must satisfy the divisibility
  constraints on their trailing dimension(s).

  Args:
    block_mappings: The block mappings to validate.
    lowering_context: Unused.
    name_and_src_info: Identifies the pallas_call in error messages.

  Raises:
    ValueError: If any block mapping violates one of the constraints.
  """
  del lowering_context  # originally needed for forward compat
  for bm in block_mappings:
    rank = len(bm.block_shape)
    # TODO(necula): add tests for SMEM blocks with trivial windowing
    # We support scalars too
    if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and
        bm.has_trivial_window()):
      continue
    if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE:
      continue

    def err_details():
      # Only called within this iteration, so capturing `bm` is safe.
      return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} "
              "has block shape "
              f"{bm.block_shape}, array shape {bm.array_shape_dtype.shape}, "
              # TODO(necula): add index_map source location info
              f"and index_map {bm.index_map_jaxpr.jaxpr}, in "
              f"memory space {bm.block_aval.memory_space}."
              "\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec")
    if rank < 1:
      raise ValueError(
          "The Pallas TPU lowering currently supports only blocks of "
          "rank >= 1. " + err_details())

    if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and
        not bm.has_trivial_window()):
      # The trailing space keeps the appended details from running into the
      # preceding sentence.
      raise ValueError(
          "The Pallas TPU lowering currently supports in memory space ANY "
          "only blocks having the same block shape as the array shape "
          "and a trivial index_map (returning all 0s). " + err_details())

    # Mapped dimensions count as size 1 for the divisibility checks.
    unmapped_bs = [
        1 if bs is pallas_core.mapped else bs for bs in bm.block_shape]
    bs0, as0 = unmapped_bs[-1], bm.array_shape_dtype.shape[-1]

    if rank >= 2:
      bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
      evenly_divisible = (
          (bs0 == as0 or bs0 % 128 == 0) and
          (bs1 == as1 or bs1 % 8 == 0)
      )
      if not evenly_divisible:
        raise ValueError(
            "The Pallas TPU lowering currently requires that the last two "
            "dimensions of your block shape are divisible by 8 and 128 "
            "respectively, or be equal to the respective dimensions of the "
            "overall array. "
            + err_details()
        )
    else:
      # rank == 1 here, since rank < 1 was rejected above.
      # bools get a bitwidth of 32 due to how mosaic handles them
      if bm.array_shape_dtype.dtype == jnp.bool_:
        bitwidth = 32
      else:
        bitwidth = lax_internal._bit_width(bm.array_shape_dtype.dtype)
      packing = 32 // bitwidth
      tiling_size = 128 * packing
      evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0)
      if not evenly_divisible:
        # Report the bitwidth actually used above so that the arithmetic in
        # the message matches `tiling_size` (notably for bool arrays).
        raise ValueError(
            "The Pallas TPU lowering currently requires that rank 1 block"
            " shapes, either 1) the first (and only) dimension of the block"
            " shape is equal to the first (and only) dimension of the array"
            " shape, or 2) the first (and only) dimension of the block shape"
            f" is a multiple of the tiling size ({tiling_size} = 128 * (32 //"
            f" {bitwidth})) of the"
            " array shape. "
            + err_details()
        )


def lower_jaxpr_to_module(
    lowering_context: mlir.LoweringRuleContext,
    ctx: ir.Context,
    grid_mapping: pallas_core.GridMapping,
    jaxpr: jax_core.Jaxpr,
    *,
    dimension_semantics: tuple[str | None, ...] | None,
    name_and_src_info: pallas_core.NameAndSrcInfo,
    mesh: mesh_lib.Mesh | None = None,
    for_verification: bool = False,
) -> tuple[Module, tuple[Any, ...]]:
  """Lowers a Pallas kernel jaxpr and its block mappings to an MLIR module.

  The module contains a `main` function for the kernel body and, when the grid
  is non-empty, one `transform_{i}` function per windowed block mapping. Grid
  bounds, window parameters, scalar-prefetch/scratch counts and dimension
  semantics are attached to `main` as attributes.

  Args:
    lowering_context: Only forwarded to `_check_block_mappings`.
    ctx: The MLIR context to build in.
    grid_mapping: The Pallas grid mapping for the call.
    jaxpr: The kernel jaxpr.
    dimension_semantics: Semantics for each non-vmapped grid dimension.
    name_and_src_info: Supplies the module name and error-message context.
    mesh: Device mesh; needed only if the kernel communicates.
    for_verification: Stored on the lowering contexts that are created.

  Returns:
    The MLIR module and the extra arguments from the Mosaic grid mapping.

  Raises:
    ValueError: If a block mapping is not supported (see
      `_check_block_mappings`) or the grid mapping is inconsistent.
  """
  # Verify that we have legal block mappings to catch errors early.
  _check_block_mappings(grid_mapping.block_mappings, lowering_context,
                        name_and_src_info)

  mosaic_grid_mapping = MosaicGridMapping(
      jaxpr, grid_mapping, dimension_semantics, mesh)
  mosaic_grid_mapping.maybe_compress_grid()
  m = ir.Module.create()
  attrs = m.operation.attributes
  module_name = name_and_src_info.name
  attrs["sym_name"] = ir.StringAttr.get(module_name)
  sym_tab = ir.SymbolTable(m.operation)
  func_op = lower_jaxpr_to_func(
      ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
      name="main", for_verification=for_verification,
  )
  m.body.append(func_op)
  sym_tab.insert(func_op)
  # One entry per block mapping, in order; left empty when there is no grid.
  window_params = []
  grid = mosaic_grid_mapping.grid
  if grid:
    for i, bm in enumerate(grid_mapping.block_mappings):
      func_name = f"transform_{i}"
      # ANY and SEMAPHORE operands don't support windowing and require empty window_params.
      tpu_memory_space = _memory_space_to_tpu_memory_space(
          bm.block_aval.memory_space)
      if (
          tpu_memory_space == tpu_core.TPUMemorySpace.ANY
          or tpu_memory_space == tpu_core.TPUMemorySpace.SEMAPHORE
      ):
        # We checked above that the block does not require windowing.
        window_params.append(ir.DictAttr.get())
        continue
      mlir_func = lower_jaxpr_to_transform_func(
          ctx,
          bm.index_map_jaxpr.jaxpr,
          bm.block_aval,
          name=func_name,
          mosaic_grid_mapping=mosaic_grid_mapping,
          for_verification=for_verification,
      )
      assert mlir_func.verify(), mlir_func
      # Mapped dimensions have size 1 in the window bounds.
      block_shape = [
          1 if b is pallas_core.mapped else b for b in bm.block_shape
      ]
      # If we have an extended dtype, we need to add the block shape for the
      # remaining physical dtype.
      block_shape += list(_get_aval_physical_dtype_shape(bm.block_aval.inner_aval))
      window_shape = ir.DenseI64ArrayAttr.get(block_shape)
      block_params = dict(
          window_bounds=window_shape,
          transform_indices=ir.FlatSymbolRefAttr.get(func_name),
      )
      if isinstance(bm.indexing_mode, pallas_core.Unblocked):
        # Unblocked indexing is expressed as an element window with padding;
        # no padding means zeros on both sides of every dimension.
        if bm.indexing_mode.padding is None:
          pad_low = pad_high = [0] * len(bm.block_shape)
        else:
          pad_low, pad_high = map(list, zip(*bm.indexing_mode.padding))
        block_params["window_kind"] = ir.Attribute.parse(
            f"#tpu.element_window<{pad_low},{pad_high}>"
        )
      window_params.append(ir.DictAttr.get(block_params))
      m.body.append(mlir_func)
      sym_tab.insert(mlir_func)
    func_op.attributes["window_params"] = ir.ArrayAttr.get(window_params)
    # Dynamic grid dimensions are encoded with MLIR's dynamic-size sentinel.
    static_grid = [
        MLIR_DYNAMIC if b is pallas_core.dynamic_grid_dim else b for b in grid
    ]
    func_op.attributes["iteration_bounds"] = ir.DenseI64ArrayAttr.get(static_grid)

  # These attributes are set whether or not there is a grid.
  func_op.attributes["scalar_prefetch"] = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scalar_prefetch_types))
  func_op.attributes["scratch_operands"] = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64), len(mosaic_grid_mapping.scratch_types))
  func_op.attributes["dimension_semantics"] = (
      mosaic_grid_mapping.get_dimension_semantics()
  )
  return m, mosaic_grid_mapping.get_extra_args()


def lower_jaxpr_to_transform_func(
    ctx: ir.Context,
    jaxpr: jax_core.Jaxpr,
    aval: jax_core.AbstractValue,
    *,
    name: str,
    mosaic_grid_mapping: MosaicGridMapping,
    for_verification: bool,
) -> func.FuncOp:
  """Lowers an index-map jaxpr to an MLIR function named `name`.

  The function takes the grid indices followed by the scalar prefetch
  arguments and returns the block indices computed by `jaxpr`, followed by
  one zero index per extra physical dimension of the block's dtype.

  Args:
    ctx: The MLIR context.
    jaxpr: The index-map jaxpr.
    aval: The block's abstract value; must be an `AbstractRef`.
    name: Name of the generated function.
    mosaic_grid_mapping: Supplies argument types, grid and mesh information.
    for_verification: Stored on the lowering context.

  Returns:
    The verified `func.FuncOp`.

  Raises:
    A verification error converted by `error_handling` if the generated
    function fails MLIR verification.
  """
  grid_types = mosaic_grid_mapping.grid_types
  num_grid = len(grid_types)
  arg_types = [*grid_types, *mosaic_grid_mapping.scalar_prefetch_types]

  def body_func(*args):
    grid_indices, scalar_prefetch = split_list(args, [num_grid])
    jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
        grid_indices, maybe_include_mapped_dims=True
    )
    # Grid indices are scalars, hence the empty block shapes.
    arg_block_shapes = [()] * len(jaxpr_indices)
    arg_block_shapes.extend(mosaic_grid_mapping.scalar_prefetch_block_shapes)

    mesh_info = mosaic_grid_mapping.mesh_info
    mesh_context = None
    if mesh_info is not None:
      mesh_context = MeshContext(
          mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
      )
    lowering_context = LoweringContext(
        ir_context=ctx,
        grid_sizes=mosaic_grid_mapping.grid,
        grid_names=mosaic_grid_mapping.grid_names,
        mapped_dims=mosaic_grid_mapping.mapped_dims,
        user_grid_indices=None,
        block_shapes=arg_block_shapes,
        name_stack=source_info_util.NameStack(),
        mesh_context=mesh_context,
        traceback_caches=mlir.TracebackCaches(),
        for_verification=for_verification,
    )
    block_indices = jaxpr_subcomp(
        lowering_context, jaxpr, *jaxpr_indices, *scalar_prefetch
    )
    assert isinstance(aval, state.AbstractRef), aval
    # If we have an extended dtype, we need to add 0s for the block indices
    # for the remaining physical dtype.
    zero_index = ir_constant(
        0, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
    num_physical_dims = len(_get_aval_physical_dtype_shape(aval.inner_aval))
    block_indices += [zero_index] * num_physical_dims
    return block_indices

  body_func.__name__ = name
  transform = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  try:
    transform.func_op.verify()
  except ir.MLIRError as e:
    raise error_handling.mlir_error_to_verification_error(e) from e
  return transform.func_op


def lower_jaxpr_to_func(
    ctx: ir.Context,
    jaxpr: jax_core.Jaxpr,
    *,
    mosaic_grid_mapping: MosaicGridMapping,
    name: str,
    for_verification: bool,
) -> func.FuncOp:
  """Lowers the kernel jaxpr to an MLIR function named `name`.

  The function's arguments are the grid indices, scalar prefetch arguments,
  operands and scratch buffers, in that order. The grid indices are not
  passed to the jaxpr; the selected ones are stored on the lowering context
  as `user_grid_indices`.

  Args:
    ctx: The MLIR context.
    jaxpr: The kernel jaxpr.
    mosaic_grid_mapping: Supplies argument types, block shapes, grid and
      mesh information.
    name: Name of the generated function.
    for_verification: Stored on the lowering context.

  Returns:
    The verified `func.FuncOp`.

  Raises:
    A verification error converted by `error_handling` if the generated
    function fails MLIR verification.
  """
  gm = mosaic_grid_mapping
  num_grid = len(gm.grid_types)
  num_scalar_prefetch = len(gm.scalar_prefetch_types)
  arg_types = [
      *gm.grid_types,
      *gm.scalar_prefetch_types,
      *gm.operand_types,
      *gm.scratch_types,
  ]
  # Block shapes are tracked only for the arguments the jaxpr actually sees.
  arg_block_shapes = [
      *gm.scalar_prefetch_block_shapes,
      *gm.operand_block_shapes,
      *gm.scratch_block_shapes,
  ]

  def body_func(*args):
    grid_indices, scalar_prefetch, operands_and_scratch = split_list(
        args, [num_grid, num_scalar_prefetch])
    jaxpr_indices = mosaic_grid_mapping.get_grid_indices(
        grid_indices, maybe_include_mapped_dims=False
    )
    mesh_info = mosaic_grid_mapping.mesh_info
    mesh_context = None
    if mesh_info is not None:
      mesh_context = MeshContext(
          mesh_info.mesh_shape, mesh_info.axis_names, mesh_info.mesh_strides
      )
    lowering_context = LoweringContext(
        ir_context=ctx,
        grid_sizes=mosaic_grid_mapping.grid,
        grid_names=mosaic_grid_mapping.grid_names,
        mapped_dims=mosaic_grid_mapping.mapped_dims,
        user_grid_indices=jaxpr_indices,
        block_shapes=arg_block_shapes,
        name_stack=source_info_util.NameStack(),
        mesh_context=mesh_context,
        traceback_caches=mlir.TracebackCaches(),
        for_verification=for_verification,
    )
    return jaxpr_subcomp(
        lowering_context, jaxpr, *scalar_prefetch, *operands_and_scratch
    )

  body_func.__name__ = name
  kernel = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  try:
    kernel.func_op.verify()
  except ir.MLIRError as e:
    raise error_handling.mlir_error_to_verification_error(e) from e
  return kernel.func_op


def lower_fun(fun: Callable, *, multiple_results: bool) -> Callable:
  """Builds a lowering rule by tracing `fun` and lowering the traced jaxpr.

  Args:
    fun: A function to trace on the rule's input avals.
    multiple_results: Whether `fun` returns a sequence of results.

  Returns:
    A lowering rule `(ctx, *args, **params)` that returns a sequence of
    values when `multiple_results` is true and a single value otherwise.
    The rule raises NotImplementedError if tracing produces constants.
  """
  def f_lowered(ctx: LoweringRuleContext, *args, **params):
    # Single-result functions are wrapped so the traced function always
    # returns a tuple.
    f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
    traced = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f, params), ctx.avals_in)
    jaxpr, _, consts, () = traced
    if consts:
      raise NotImplementedError
    closed_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
    inner_context = ctx.lowering_context.replace(
        block_shapes=ctx.block_shapes)
    results = jaxpr_subcomp(inner_context, closed_jaxpr, *consts, *args)
    return results if multiple_results else results[0]

  return f_lowered


class LoweringException(Exception):
  """Error raised by `jaxpr_subcomp` with extra lowering diagnostics.

  It wraps an exception thrown by a lowering rule (when verbose errors are
  enabled) and carries the failing jaxpr equation in its message.
  """
  pass


def _compute_name_stack_updates(
    old_name_stack: list[str],
    new_name_stack: list[str]
) -> tuple[list[str], list[str]]:
  """Computes the popped/pushed items to the name stack after an update.

  Args:
    old_name_stack: The name stack prior to the update.
    new_name_stack: The name stack after the update.

  Returns:
    popped: A list of names popped from the name stack as part of the update.
    pushed: A list of names pushed to the name stack as part of the update.
  """
  common_prefix_idx = 0
  for i, (old, new) in enumerate(unsafe_zip(old_name_stack, new_name_stack)):
    if old == new:
      common_prefix_idx = i+1
    else:
      break
  return old_name_stack[common_prefix_idx:], new_name_stack[common_prefix_idx:]


def jaxpr_subcomp(
    ctx: LoweringContext, jaxpr: jax_core.Jaxpr, *args: ir.Value
) -> Sequence[ir.Value]:
  """Lowers `jaxpr` equation by equation using the registered lowering rules.

  Args:
    ctx: The lowering context; `ctx.block_shapes` supplies one block shape
      per jaxpr input.
    jaxpr: The jaxpr to lower; it must have no constvars.
    *args: The values bound to the jaxpr's inputs.

  Returns:
    The values of the jaxpr's outputs; literal outputs are turned into
    constants.

  Raises:
    NotImplementedError: If an equation's primitive has no lowering rule.
    LoweringException: When verbose errors are enabled and a lowering rule
      raises; the original exception is chained as the cause.
  """
  assert not jaxpr.constvars
  # Maps jaxpr variables to their lowered values and known block shapes.
  env = {}
  block_shape_env = {}

  def read_block_shape(atom: jax_core.Atom):
    # Literals and variables without a recorded block shape yield None.
    if isinstance(atom, jax_core.Literal):
      return None
    return block_shape_env.get(atom, None)

  def read_env(atom: jax_core.Atom):
    # Literals are returned as their raw Python/NumPy value.
    return atom.val if isinstance(atom, jax_core.Literal) else env[atom]

  def write_env(var: jax_core.Var, val):
    is_valid_type = isinstance(val, (ir.Value, KeyScalarBundle))
    assert is_valid_type, type(val)
    env[var] = val

  for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
    block_shape_env[invar] = bs
  map(write_env, jaxpr.invars, args)

  # Track the name stack so that trace ops can be emitted at scope changes.
  initial_name_stack = [scope.name for scope in ctx.name_stack.stack]
  current_name_stack: list[str] = []
  # TODO(justinfu): Handle transform scopes.
  current_name_stack.extend(initial_name_stack)
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    source_info = eqn.source_info.replace(
        name_stack=ctx.name_stack + eqn.source_info.name_stack
    )
    loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info)
    with source_info_util.user_context(eqn.source_info.traceback), loc:
      if eqn.primitive in lowering_rules:
        # Unless the primitive opted out, make every input an MLIR value.
        if eqn.primitive not in skip_mlir_conversions:
          invals = [_ensure_mlir_value(x, v.aval)
                    for x, v in zip(invals, eqn.invars)]
        block_shapes = map(read_block_shape, eqn.invars)
        rule_context = LoweringRuleContext(
            ctx,
            [v.aval for v in eqn.invars],
            [v.aval for v in eqn.outvars],
            block_shapes,
        )

        # Insert trace_start and trace_stop ops on named_scope boundaries.
        name_stack = [scope.name for scope in source_info.name_stack.stack]
        popped, pushed = _compute_name_stack_updates(
            current_name_stack, name_stack)
        current_name_stack = name_stack
        for _ in popped:
          tpu.TraceStopOp()
        for name in pushed:
          tpu.TraceStartOp(message=name, level=10)

        try:
          ans = lowering_rules[eqn.primitive](
              rule_context, *invals, **eqn.params
          )
        except LoweringException:
          raise  # We only add the extra info to the innermost exception.
        except Exception as e:
          if not pallas_call._verbose_errors_enabled():
            raise
          msg = (f"{type(e).__name__}: {e}\n" +
                "Additional diagnostics: \n" +
                f"Failing jaxpr equation: {eqn}\n")
          new_error = LoweringException(msg)
          # We insert the traceback here so that the user code shows
          # up in the traceback for the post-transform error.
          if source_info.traceback is not None:
            tb = source_info.traceback.as_python_traceback()
            new_error.__traceback__ = traceback_util.filter_traceback(tb)
          raise new_error from e
      else:
        raise NotImplementedError(
            "Unimplemented primitive in Pallas TPU lowering: "
            f"{eqn.primitive.name}. "
            "Please file an issue on https://github.com/jax-ml/jax/issues.")
      # Record the rule's result(s) for later equations to read.
      if eqn.primitive.multiple_results:
        map(write_env, eqn.outvars, ans)
      else:
        write_env(eqn.outvars[0], ans)

  # Drain the name stack at the end of a jaxpr and insert trace_stop ops.
  popped, pushed = _compute_name_stack_updates(
      current_name_stack, initial_name_stack)
  for _ in popped:
    tpu.TraceStopOp()
  assert len(pushed) == 0

  outvals = map(read_env, jaxpr.outvars)
  # Literal outputs were read as raw values; materialize them as constants.
  outvals = [
      ir_constant(x) if isinstance(var, jax_core.Literal) else x
      for x, var in zip(outvals, jaxpr.outvars)
  ]
  return outvals


def _ensure_mlir_value(val, aval):
  """Returns `val` in a form that lowering rules can consume.

  `ir.Value`s and `KeyScalarBundle`s are passed through untouched. NumPy
  scalars/arrays and Python ints/floats are turned into constants whose MLIR
  type is derived from `aval.dtype`.

  Raises:
    RuntimeError: if `val` is of any other type.
  """
  if isinstance(val, (ir.Value, KeyScalarBundle)):
    # Already usable by the lowering rules as-is.
    return val
  if not isinstance(val, (np.generic, np.ndarray, int, float)):
    raise RuntimeError(
        f"Unsupported argument to a JAX primitive of type: {type(val)}"
    )
  return ir_constant(val, _dtype_to_ir_type(aval.dtype))


def _get_lowering_rule(
    ctx: LoweringRuleContext, ref, *idx, tree,
):
  """Lowers `get_p` by delegating to the more general `_load_lowering_rule`.

  The flat index operands (and their avals) are rebuilt into indexers using
  `tree`, then repacked as the `(ref, transforms, mask, _)` structure that
  `_load_lowering_rule` unflattens, with `None` in the last two slots.
  """
  ref_aval, *_ = ctx.avals_in
  indexers = tree_util.tree_unflatten(tree, idx)
  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[1:])
  load_args, load_tree = tree_util.tree_flatten((ref, indexers, None, None))
  load_avals = tree_util.tree_leaves((ref_aval, indexers_avals, None, None))
  # Only the ref keeps a block shape; every other flattened operand gets None.
  load_block_shapes = [ctx.block_shapes[0]] + [None] * (len(load_avals) - 1)
  load_ctx = ctx.replace(
      avals_in=load_avals,
      block_shapes=load_block_shapes,
  )
  return _load_lowering_rule(load_ctx, *load_args, args_tree=load_tree)


# Register the lowering for the state `get_p` primitive.
# NOTE(review): membership in `skip_mlir_conversions` presumably makes the
# jaxpr interpreter pass operands through without first converting them to
# MLIR values -- confirm where the set is consumed.
lowering_rules[state_primitives.get_p] = _get_lowering_rule
skip_mlir_conversions.add(state_primitives.get_p)


def _swap_lowering_rule(
    ctx: LoweringRuleContext,
    ref,
    val,
    *idx,
    tree
):
  """Lowers `swap_p` by delegating to `_masked_swap_lowering_rule`.

  The flat index operands (and their avals) are rebuilt into indexers using
  `tree`, then repacked as the `(ref, transforms, val, mask)` structure that
  `_masked_swap_lowering_rule` unflattens, with `None` as the mask.
  """
  ref_aval, val_aval, *_ = ctx.avals_in
  indexers = tree_util.tree_unflatten(tree, idx)
  indexers_avals = tree_util.tree_unflatten(tree, ctx.avals_in[2:])
  swap_args, swap_tree = tree_util.tree_flatten((ref, indexers, val, None))
  swap_avals = tree_util.tree_leaves(
      (ref_aval, indexers_avals, val_aval, None)
  )
  # Only the ref keeps a block shape; every other flattened operand gets None.
  swap_block_shapes = [ctx.block_shapes[0]] + [None] * (len(swap_avals) - 1)
  swap_ctx = ctx.replace(
      avals_in=swap_avals,
      block_shapes=swap_block_shapes,
  )
  return _masked_swap_lowering_rule(swap_ctx, *swap_args, args_tree=swap_tree)

# Register the lowering for the state `swap_p` primitive.
# NOTE(review): membership in `skip_mlir_conversions` presumably makes the
# jaxpr interpreter pass operands through without first converting them to
# MLIR values -- confirm where the set is consumed.
lowering_rules[state_primitives.swap_p] = _swap_lowering_rule
skip_mlir_conversions.add(state_primitives.swap_p)


def _make_index(s):
  """Returns `s` as a value of MLIR `index` type.

  Python ints and NumPy arrays become `index` constants; values already of
  `index` type are returned unchanged; anything else is cast with
  `arith.index_cast`.
  """
  index_type = ir.IndexType.get()
  if isinstance(s, (int, np.ndarray)):
    return ir_constant(s, index_type)
  already_index = s.type == index_type
  return s if already_index else arith.index_cast(index_type, s)


def _maybe_cast_to_index(cast_to_index, x):
  """Converts `x` to an `index` value, or just to an MLIR value.

  When `cast_to_index` is falsy, `x` is only passed through
  `_ensure_mlir_value` using `pallas_core.index_map_grid_aval` as its aval.
  """
  if not cast_to_index:
    return _ensure_mlir_value(x, aval=pallas_core.index_map_grid_aval)
  return _make_index(x)


def _index_to_start_size_stride(
    idx: tuple[indexing.Slice | int | ir.Value, ...], cast_to_index: bool
) -> tuple[ir.Value, int | ir.Value, int, bool]:
  assert not isinstance(idx, slice)
  if isinstance(idx, indexing.Slice):
    start = _maybe_cast_to_index(cast_to_index, idx.start)
    size = idx.size
    stride = idx.stride
    squeeze = False
  elif isinstance(idx, int):
    start = _maybe_cast_to_index(cast_to_index, idx)
    size = 1
    stride = 1
    squeeze = True
  else:
    if np.shape(idx):
      raise ValueError(f"Can only use ()-shaped and slice indexing: {idx}")
    start = _maybe_cast_to_index(cast_to_index, idx)
    size = 1
    stride = 1
    squeeze = True
  return start, size, stride, squeeze


def _indexer_to_start_size_stride(
    indexer: NDIndexer,
    ref_block_shape: tuple[int | pallas_core.Mapped, ...],
    *,
    cast_to_index: bool,
) -> tuple[
    tuple[ir.Value, ...],
    tuple[int | ir.Value, ...],
    tuple[int, ...],
    tuple[bool, ...],
    tuple[int | pallas_core.Mapped, ...],
]:
  indices_iter = iter(indexer.indices)
  starts, sizes, strides, squeeze_dims = [], [], [], []
  for s in ref_block_shape:
    start, size, stride, squeeze_dim = (
        (
            _maybe_cast_to_index(cast_to_index, 0),
            1,
            1,
            True,
        )
        if s is pallas_core.mapped
        else _index_to_start_size_stride(next(indices_iter), cast_to_index)
    )
    starts.append(start)
    sizes.append(size)
    strides.append(stride)
    squeeze_dims.append(squeeze_dim)
  next_index = next(indices_iter, None)
  assert next_index is None, (indexer.indices, ref_block_shape)
  new_ref_block_shape = tuple(s for s, squeeze in zip(sizes, squeeze_dims)
                              if not squeeze)
  return (
      tuple(starts),
      tuple(sizes),
      tuple(strides),
      tuple(squeeze_dims),
      new_ref_block_shape,
  )


def _slice_memref(
    ref: ir.Value,
    indexer: NDIndexer,
    ref_dtype: DTypeLike,
    ref_block_shape: tuple[int | pallas_core.Mapped, ...],
) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]:
  assert ref_block_shape is not None
  target_shape = indexer.get_indexer_shape()
  starts, sizes, strides, squeeze_dims, ref_block_shape = (
      _indexer_to_start_size_stride(
          indexer,
          ref_block_shape,
          cast_to_index=False,
      )
  )
  if not all((s is None or s == 1) for s in strides):
    raise NotImplementedError("Strided slices of references are unsupported.")
  dynamic_sizes = tuple(s for s in sizes if isinstance(s, ir.Value))
  ir_dynamic_size = ir.ShapedType.get_dynamic_size()
  static_sizes = tuple(s if not isinstance(s, ir.Value)
                       else ir_dynamic_size for s in sizes)
  target_ref_ty = ir.MemRefType.get(
      static_sizes,
      _dtype_to_ir_type(ref_dtype),
      memory_space=ref.type.memory_space,
  )
  out = tpu.memref_slice(target_ref_ty, ref, starts, dynamic_sizes)
  if any(squeeze_dims):
    # We need to squeeze out some dimensions
    static_sizes = tuple(s if not isinstance(s, ir.Value)
                         else ir_dynamic_size for s in target_shape)
    squeezed_ref_ty = ir.MemRefType.get(
        static_sizes,
        _dtype_to_ir_type(ref_dtype),
        memory_space=ref.type.memory_space,
    )
    out = tpu.memref_squeeze(squeezed_ref_ty, out)
  return out, ref_block_shape


def _bitcast_memref(
    ref: ir.Value,
    bitcaster: RefBitcaster,
    ref_dtype: DTypeLike,
    ref_block_shape: tuple[int | pallas_core.Mapped, ...],
) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]:
  """Bitcasts `ref` to `bitcaster.dtype` / `bitcaster.shape`.

  Returns:
    `(new_ref, new_dtype, new_block_shape)`. When the block shape has at least
    two dims and the second-minormost one is not mapped, that dim is rescaled
    by `src_bitwidth // dst_bitwidth`.

  Raises:
    NotImplementedError: if the bitwidth changes and the block shape is 1D or
      its second-minormost dimension is mapped.
  """
  src_bits = dtype_bitwidth(ref_dtype)
  dst_bits = dtype_bitwidth(bitcaster.dtype)
  if src_bits != dst_bits:
    if len(ref_block_shape) < 2:
      raise NotImplementedError(
          "Bitcast 1D ref with bitwidth change is not supported."
      )
    if ref_block_shape[-2] is pallas_core.mapped:
      raise NotImplementedError(
          "Bitcast a ref whose 2nd minormost dimension is squeezed when"
          " bitwidth changes."
      )
  new_ref_dtype = bitcaster.dtype
  target_ref_ty = ir.MemRefType.get(
      bitcaster.shape,
      _dtype_to_ir_type(new_ref_dtype),
      memory_space=ref.type.memory_space,
  )
  block_dims = list(ref_block_shape)
  rescale = len(block_dims) >= 2 and block_dims[-2] is not pallas_core.mapped
  if rescale:
    block_dims[-2] = block_dims[-2] * src_bits // dst_bits
  bitcast_ref = tpu.memref_bitcast(target_ref_ty, ref)
  return bitcast_ref, new_ref_dtype, tuple(block_dims)


def _reshape_memref(
    ref: ir.Value,
    reshaper: RefReshaper,
    ref_dtype: DTypeLike,
    ref_block_shape: tuple[int | pallas_core.Mapped, ...],
) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]:
  if ref_dtype != reshaper.dtype:
    raise ValueError(
        f"Reshape a ref with dtype change: {reshaper.dtype} vs {ref_dtype}"
    )
  if len(ref_block_shape) < 2:
    raise NotImplementedError("Reshape 1D ref is not supported.")
  if (
      ref_block_shape[-2] is pallas_core.mapped
      or ref_block_shape[-1] is pallas_core.mapped
  ):
    raise NotImplementedError(
        "Reshape a ref with squeezed dimension on last two dimensions."
    )
  if np.prod(ref_block_shape) != np.prod(reshaper.shape):
    raise ValueError(
        f"Reshape a ref with different number of elements: {ref_block_shape} "
        f"vs {reshaper.shape}"
    )
  target_ref_ty = ir.MemRefType.get(
      reshaper.shape,
      _dtype_to_ir_type(reshaper.dtype),
      memory_space=ref.type.memory_space,
  )
  return (
      tpu.memref_reshape(target_ref_ty, ref),
      reshaper.shape,
  )


def _transform_ref(ref, ref_dtype, ref_block_shape, transforms):
  for transform in transforms:
    match transform:
      case NDIndexer():
        ref, ref_block_shape = _slice_memref(
            ref, transform, ref_dtype, ref_block_shape
        )
      case RefBitcaster():
        ref, ref_dtype, ref_block_shape = _bitcast_memref(
            ref, transform, ref_dtype, ref_block_shape
        )
      case RefReshaper():
        ref, ref_block_shape = _reshape_memref(
            ref, transform, ref_dtype, ref_block_shape
        )
      case _:
        raise NotImplementedError(f"Unsupported transform: {transform}")
  return ref, ref_block_shape


@dataclasses.dataclass(frozen=True)
class KeyScalarBundle:
  """A container class for PRNG key data.

  We pass around keys as a KeyScalarBundle in the lowering pass rather than
  as a vector, since we want the key data to live in scalar registers rather
  than vector registers. This special dataclass exists so we can return
  multiple scalar values from load_op, because the load_op primitive does
  not allow multiple results.

  Attributes:
    key_shape: The shape of the key data. The PRNG key load rule in this file
      fills it from the key implementation's `key_shape`.
    scalars: A list of OpResults representing scalar key data during the
      lowering pass.
  """
  key_shape: tuple[int, ...]
  scalars: list[ir.OpResult]

def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_):
  """Lowers `load_p`: a scalar load from SMEM or a vector load from VMEM.

  Args:
    ctx: Lowering context; `avals_in`, `avals_out` and `block_shapes` are read.
    *args_flat: Flattened `(ref, transforms, mask, _)` operands.
    args_tree: Treedef used to unflatten both `args_flat` and `ctx.avals_in`.
    **_: Additional primitive params; ignored.

  Returns:
    The loaded value (compared against zero when the output dtype is bool),
    or a `KeyScalarBundle` when the output is a Pallas PRNG key.

  Raises:
    NotImplementedError: if a mask is given, or a non-SMEM ref has an empty
      block shape.
    ValueError: for PRNG keys outside SMEM, non-scalar (array) indices,
      non-scalar SMEM loads, or refs that are in neither VMEM nor SMEM.
  """
  ref, transforms, mask, _ = args_tree.unflatten(args_flat)
  ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in)
  # All transforms but the last are applied to the ref itself; the last one is
  # the indexer that drives the actual load.
  (*prev_transforms, idx) = transforms
  # Select last aval, which is the one that will be used for the load.
  (*_, idx_aval) = transforms_avals

  if mask is not None:
    raise NotImplementedError

  ref_block_shape, *_ = ctx.block_shapes
  ref, ref_block_shape = _transform_ref(
      ref, ref_aval.dtype, ref_block_shape, prev_transforms
  )
  ref_type = ir.MemRefType(ref.type)
  # The memory space is identified by its printed attribute form.
  is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space<smem>"
  (aval_out,) = ctx.avals_out
  if isinstance(aval_out.dtype, prng.KeyTy) and pl_random.is_pallas_impl(
      aval_out.dtype._impl
  ):
    if not is_smem_load:
      raise ValueError("PRNG keys must be loaded from SMEM. Did you set "
                       "the memory space to TPUMemorySpace.SMEM in the "
                       "BlockSpec for the PRNG key input?")
    # Note: the key rule receives the original flat args, i.e. the ref before
    # `prev_transforms` were applied above.
    return _prng_key_load_lowering_rule(ctx, *args_flat, args_tree=args_tree)
  if not is_smem_load and not ref_block_shape:
    raise NotImplementedError(
        "Indexing into a ()-shaped Ref not yet supported on TPU.")
  # Any non-Slice index whose aval has a non-empty shape is rejected.
  if any(
      (not isinstance(a, primitives.Slice) and a.shape)
      for a in idx_aval.indices
  ):
    raise ValueError("Cannot do int indexing on TPU")
  starts, sizes, strides, _, _ = _indexer_to_start_size_stride(
      idx,
      ref_block_shape,
      cast_to_index=True,
  )
  need_stride = not all((s is None or s == 1) for s in strides)
  if is_smem_load:
    if ctx.avals_out[0].shape:
      raise ValueError("Can only load scalars from SMEM")
    return _maybe_cast_load_to_bool(aval_out, memref.load(ref, starts))
  elif str(ref_type.memory_space) != "#tpu.memory_space<vmem>":
    extra = ""
    if str(ref_type.memory_space) == "#tpu.memory_space<any>":
      extra = " ANY memory space can only be accessed using async_copy."
    raise ValueError(
        "Loads are only allowed on VMEM and SMEM references." + extra
    )
  # The load is issued with the full per-dimension sizes (including unit dims
  # for scalar indices); it is shape-cast to the output shape afterwards.
  load_aval = jax_core.ShapedArray(sizes, dtype=aval_out.dtype)
  if need_stride:
    load_val = tpu.strided_load(
      aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts, strides
    )
  else:
    load_val = vector.load(
        aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts)
  if load_aval != aval_out:
    vec_type = ir.VectorType.get(aval_out.shape,
                                _dtype_to_ir_type(aval_out.dtype,
                                                  is_kernel_boundary=True))
    load_val = vector.shape_cast(vec_type, load_val)
  return _maybe_cast_load_to_bool(aval_out, load_val)

def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle:
  """Loads a PRNG key from SMEM as individual scalars.

  The key is read with one `memref.load` per scalar instead of a single vector
  load, and the results are wrapped in a `KeyScalarBundle` so that consumers
  of keys (e.g. set_seed) can special-case them.

  Raises:
    NotImplementedError: if the key implementation's `key_shape` is not 2D or
      is not `(1, 1)`.
  """
  ref, _, _, _ = args_tree.unflatten(args_flat)
  (aval_out,) = ctx.avals_out
  assert isinstance(aval_out.dtype, prng.KeyTy)
  key_shape = aval_out.dtype._impl.key_shape

  if len(key_shape) != 2:
    raise NotImplementedError("Seed key_data must be 2D.")
  if tuple(key_shape) != (1, 1):
    raise NotImplementedError(
      f"Seed key_data of shape != (1, 1) not supported. Got: {key_shape}")

  def _load_scalar(i):
    indexer = NDIndexer(indices=(0, i), shape=key_shape,
                        int_indexer_shape=tuple())
    starts = _indexer_to_start_size_stride(
        indexer,
        key_shape,
        cast_to_index=True,
    )[0]
    return memref.load(ref, starts)

  scalars = [_load_scalar(i) for i in range(key_shape[0])]
  return KeyScalarBundle(scalars=scalars, key_shape=tuple(key_shape))


# Register the lowering for the Pallas `load_p` primitive.
# NOTE(review): membership in `skip_mlir_conversions` presumably makes the
# jaxpr interpreter pass operands through without first converting them to
# MLIR values -- confirm where the set is consumed.
lowering_rules[primitives.load_p] = _load_lowering_rule
skip_mlir_conversions.add(primitives.load_p)

def _maybe_cast_load_to_bool(
    out_aval, val: ir.Value) -> tuple[ir.Value, jnp.dtype]:
  """Casts a memref load value to bool if the requested value is a bool.

  Mosaic does not support boolean-type memrefs, since booleans
  typically live in mask registers. We instead load booleans as integers from
  memrefs and move them to mask registers on load using this function.

  Args:
    out_aval: The output aval of the load.
    val: The input value.

  Returns:
    The loaded value, and the JAX dtype of the input value.
  """
  if out_aval.dtype != jnp.bool_:
    return val
  load_scalar_type = _dtype_to_ir_type(BOOL_MEMREF_TYPE)
  pred = _cmpsi_lowering_types[lax.ne_p]
  predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
  const_zero = ir.IntegerAttr.get(load_scalar_type, 0)
  if out_aval.shape:  # Vector case.
    load_vector_type = aval_to_ir_type(out_aval, is_kernel_boundary=True)
    vector_zeros = arith.ConstantOp(
        load_vector_type,
        ir.DenseElementsAttr.get_splat(load_vector_type, const_zero)
    )
    return arith.cmpi(predicate, val, vector_zeros)
  else:  # Scalar case.
    const_zero = arith.ConstantOp(load_scalar_type, const_zero)
    return arith.cmpi(predicate, val, const_zero)


def _maybe_cast_store_to_memref_type(
    expected_aval, val: ir.Value) -> ir.Value:
  """Casts a boolean value back to an integer for storing in a memref."""
  if expected_aval.dtype != jnp.bool_:
    return val
  int_out_type = aval_to_ir_type(expected_aval, is_kernel_boundary=True)
  return arith.extui(int_out_type, val)


def _masked_swap_lowering_rule(
    ctx: LoweringRuleContext, *args_flat, args_tree, **_
):
  """Lowers `swap_p`: loads the old contents of a ref slice, then stores `val`.

  Args:
    ctx: Lowering context; `avals_in`, `avals_out` and `block_shapes` are read.
    *args_flat: Flattened `(ref, transforms, val, mask)` operands.
    args_tree: Treedef used to unflatten both `args_flat` and `ctx.avals_in`.
    **_: Additional primitive params; ignored.

  Returns:
    The value that was loaded from the target location before the store.

  Raises:
    NotImplementedError: for masked swaps of data whose itemsize is not 4,
      masked strided stores, masked stores on jaxlib <= 0.4.35, or a non-SMEM
      ref with an empty block shape.
    ValueError: for mask/value shape mismatch, non-scalar (array) indices,
      masked or non-scalar SMEM stores, scalar VMEM stores, or refs that are
      in neither VMEM nor SMEM.
  """
  ref, transforms, val, mask = args_tree.unflatten(args_flat)
  ref_aval, transforms_avals, val_aval, mask_aval = args_tree.unflatten(
      ctx.avals_in
  )
  # All transforms but the last are applied to the ref itself; the last one is
  # the indexer that selects the region being swapped.
  (*prev_transforms, idx) = transforms
  (*_, idx_aval) = transforms_avals

  if mask is not None:
    if  val_aval.dtype.itemsize != 4:
      raise NotImplementedError("masked swap with non-32-bit data")
    if val_aval.shape != mask_aval.shape:
      raise ValueError(
          "Expected value and mask to have the same shape, but got"
          f" value shape {val_aval.shape} vs. mask shape {mask_aval.shape}."
      )

  ref_block_shape, *_ = ctx.block_shapes
  ref, ref_block_shape = _transform_ref(
      ref, ref_aval.dtype, ref_block_shape, prev_transforms
  )

  ref_type = ir.MemRefType(ref.type)
  # The memory space is identified by its printed attribute form.
  memory_space = str(ref_type.memory_space)
  is_smem_store = memory_space == "#tpu.memory_space<smem>"
  is_vmem_store = memory_space == "#tpu.memory_space<vmem>"
  (aval_out,) = ctx.avals_out
  # Non-MLIR values (e.g. Python/NumPy constants) are materialized first.
  if not isinstance(val, ir.Value):
    val = ir_constant(val, mlir_type=_dtype_to_ir_type(val_aval.dtype))
  # Any non-Slice index whose aval has a non-empty shape is rejected.
  if any(
      (not isinstance(a, primitives.Slice) and a.shape)
      for a in idx_aval.indices
  ):
    raise ValueError("Cannot do int indexing on TPU")
  if not is_smem_store and not ref_block_shape:
    raise NotImplementedError(
        "Indexing into a ()-shaped Ref not yet supported on TPU.")

  starts, _, strides, _, _ = _indexer_to_start_size_stride(
      idx,
      ref_block_shape,
      cast_to_index=True,
  )
  need_stride = not all((s is None or s == 1) for s in strides)

  if is_smem_store:
    if mask is not None:
      raise ValueError("SMEM store does not support masks")
    if val_aval.shape:
      raise ValueError("Can only store scalars to SMEM")
    # Load the old value before overwriting it; that is the swap's result.
    result = memref.load(ref, starts)
    result = _maybe_cast_load_to_bool(val_aval, result)
    val = _maybe_cast_store_to_memref_type(val_aval, val)
    memref.StoreOp(val, ref, starts)
    return result

  if not is_vmem_store:
    extra = ""
    if memory_space == "#tpu.memory_space<any>":
      extra = " ANY memory space can only be accessed using async_copy."
    raise ValueError(
        "Loads and stores are only allowed on VMEM and SMEM references." + extra
    )

  # handling VMEM store below
  if not val_aval.shape:
    raise ValueError("Cannot store scalars to VMEM")

  # Rebuild the shape of the memory region being accessed: start from the
  # output shape, re-insert a unit dim for every non-Slice index, then add a
  # unit dim for every mapped dimension of the block shape.
  mem_slice_shape = list(aval_out.shape)
  for i, a in enumerate(idx_aval.indices):
    if not isinstance(a, primitives.Slice):
      mem_slice_shape.insert(i, 1)
  mem_slice_shape_iter = iter(mem_slice_shape)
  mem_slice_shape = [
      1 if b is pallas_core.mapped else next(mem_slice_shape_iter)
      for b in ref_block_shape
  ]
  mem_aval = aval_out.update(shape=tuple(mem_slice_shape), sharding=None)
  mem_aval_vec_type = ir.VectorType.get(mem_aval.shape,
    _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True))
  # Load the old contents first; they are the result of the swap.
  if need_stride:
    result = tpu.strided_load(mem_aval_vec_type, ref, starts, strides)
  else:
    result = vector.load(mem_aval_vec_type, ref, starts)
  val = _maybe_cast_store_to_memref_type(val_aval, val)
  if mem_aval != aval_out:
    # Scalar indices were used, so the memory-side shape carries extra unit
    # dims: cast the loaded result down to the output shape and the stored
    # value up to the memory-side shape.
    result_vec_type = ir.VectorType.get(aval_out.shape,
      _dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True))
    result = vector.shape_cast(result_vec_type, result)
    val_vec_type = ir.VectorType.get(mem_aval.shape,
      _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True))
    val = vector.shape_cast(val_vec_type, val)
  result = _maybe_cast_load_to_bool(val_aval, result)

  if need_stride:
    if mask is not None:
      raise NotImplementedError("masked swap with strided store")
    tpu.StridedStoreOp(val, ref, starts, strides)
  elif jaxlib_version <= (0, 4, 35):
    # Older jaxlib: use vector.StoreOp, which is not given a mask here.
    if mask is not None:
      raise NotImplementedError("masked swap with vector store")
    vector.StoreOp(val, ref, starts)
  else:
    tpu.VectorStoreOp(val, ref, starts, [], mask=mask)
  return result


# Register the lowering for the Pallas `swap_p` primitive.
# NOTE(review): membership in `skip_mlir_conversions` presumably makes the
# jaxpr interpreter pass operands through without first converting them to
# MLIR values -- confirm where the set is consumed.
lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
skip_mlir_conversions.add(primitives.swap_p)


def _multiple_of_lowering_rule(ctx: LoweringRuleContext, val, *, values):
  del ctx
  for multiple in values:
    val = tpu.assume_multiple(val, multiple)
  return val


# Register the lowering for the Pallas `multiple_of_p` primitive.
lowering_rules[primitives.multiple_of_p] = _multiple_of_lowering_rule


def reduce_lowering_rule(reduce_fn, type_to_kind, type_to_identity):
  """Builds a lowering rule for a reduction primitive.

  Args:
    reduce_fn: jnp-style reduction (accepting `axis` and `keepdims`) used for
      the reduce-to-scalar path.
    type_to_kind: Maps a dtype class to the vector combining kind.
    type_to_identity: Maps a dtype class to the reduction's identity value.

  Returns:
    A rule `(ctx, x, *, axes)` that emits `vector.multi_reduction` for
    floating-point inputs and raises NotImplementedError for all other dtypes.
  """
  def _lowering_rule(ctx: LoweringRuleContext, x, *, axes):
    (x_aval,) = ctx.avals_in
    if not ctx.avals_out[0].shape:
      # If reducing to a scalar, we reduce by adding a leading singleton
      # dimension and reducing over all other dimensions. This avoids
      # the materialization of a scalar tensor by the reduction op which
      # is not supported.
      def _proxy_fun(val, *, axes):
        expanded = val[jnp.newaxis, ...]
        shifted_axes = [axis + 1 for axis in axes]
        reduced = reduce_fn(expanded, axis=shifted_axes, keepdims=True)
        # Squeeze lowers to vector.ExtractOp which will place the final
        # value in a scalar register.
        return jnp.squeeze(reduced)
      return lower_fun(_proxy_fun, multiple_results=False)(ctx, x, axes=axes)

    in_dtype = x_aval.dtype
    if not jnp.issubdtype(in_dtype, jnp.floating):
      is_integer = any(
          jnp.issubdtype(in_dtype, int_class)
          for int_class in (jnp.signedinteger, jnp.unsignedinteger)
      )
      if is_integer:
        raise NotImplementedError("Reductions over integers not implemented.")
      raise NotImplementedError(
          f"Reductions over {x_aval.dtype} not implemented.")

    kind = type_to_kind[jnp.floating]
    identity_attr = ir.FloatAttr.get(
        ir.F32Type.get(), type_to_identity[jnp.floating]
    )
    out_type = aval_to_ir_type(ctx.avals_out[0])
    # The accumulator starts as a splat of the reduction identity.
    acc = arith.ConstantOp(
        out_type, ir.DenseElementsAttr.get_splat(out_type, identity_attr)
    )
    return vector.multi_reduction(kind, x, acc, axes)
  return _lowering_rule


# Combining kinds and identity values for reduce_max, keyed by dtype class.
# As written, `reduce_lowering_rule` only looks up the `jnp.floating` entries;
# integer inputs raise NotImplementedError before any lookup.
REDUCE_MAX_KINDS = {
    jnp.floating: vector.CombiningKind.MAXIMUMF,
    jnp.signedinteger: vector.CombiningKind.MAXSI,
    jnp.unsignedinteger: vector.CombiningKind.MAXUI,
}
REDUCE_MAX_IDENTITY = {
    jnp.floating: float("-inf"),
    jnp.signedinteger: np.iinfo(np.int32).min,
}
_reduce_max_lowering_rule = reduce_lowering_rule(
    jnp.max, REDUCE_MAX_KINDS, REDUCE_MAX_IDENTITY)
lowering_rules[lax.reduce_max_p] = _reduce_max_lowering_rule


# Combining kinds and identity values for reduce_min, keyed by dtype class.
# As written, `reduce_lowering_rule` only looks up the `jnp.floating` entries;
# integer inputs raise NotImplementedError before any lookup.
REDUCE_MIN_KINDS = {
    jnp.floating: vector.CombiningKind.MINIMUMF,
    jnp.signedinteger: vector.CombiningKind.MINSI,
    jnp.unsignedinteger: vector.CombiningKind.MINUI,
}
REDUCE_MIN_IDENTITY = {
    jnp.floating: float("inf"),
    jnp.signedinteger: np.iinfo(np.int32).max,
}
_reduce_min_lowering_rule = reduce_lowering_rule(
    jnp.min, REDUCE_MIN_KINDS, REDUCE_MIN_IDENTITY)
lowering_rules[lax.reduce_min_p] = _reduce_min_lowering_rule


# Combining kinds and identity values for reduce_sum, keyed by dtype class.
# As written, `reduce_lowering_rule` only looks up the `jnp.floating` entries;
# integer inputs raise NotImplementedError before any lookup.
REDUCE_SUM_KINDS = {
    jnp.floating: vector.CombiningKind.ADD,
    jnp.signedinteger: vector.CombiningKind.ADD,
    jnp.unsignedinteger: vector.CombiningKind.ADD,
}
REDUCE_SUM_IDENTITY = {
    jnp.floating: 0.0,
    jnp.signedinteger: 0,
}
_reduce_sum_lowering_rule = reduce_lowering_rule(
    jnp.sum, REDUCE_SUM_KINDS, REDUCE_SUM_IDENTITY)
lowering_rules[lax.reduce_sum_p] = _reduce_sum_lowering_rule


def _reduce_and_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  """Lowers reduce_and by way of a float min-reduction."""
  def _proxy_reduce(arg, *, axes):
    # Mosaic currently only supports float reductions, so the boolean input
    # is mapped to 1.0/0.0 and reduce_and becomes "min over axes is > 0".
    # TODO(b/351017807): Implement this logic in Mosaic MultiDimReductionOp
    # instead.
    as_float = jnp.where(arg, 1.0, 0.0)
    smallest = jnp.min(as_float, axis=axes)
    return smallest > 0.0
  return lower_fun(_proxy_reduce, multiple_results=False)(ctx, x, axes=axes)

# Register the lowering for `lax.reduce_and_p`.
lowering_rules[lax.reduce_and_p] = _reduce_and_lowering_rule


def _reduce_or_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
  """Lowers reduce_or by way of a float max-reduction."""
  def _proxy_reduce(arg, *, axes):
    # Mosaic currently only supports float reductions, so the boolean input
    # is mapped to 1.0/0.0 and reduce_or becomes "max over axes is > 0".
    # TODO(b/351017807): Implement this logic in Mosaic MultiDimReductionOp
    # instead.
    as_float = jnp.where(arg, 1.0, 0.0)
    largest = jnp.max(as_float, axis=axes)
    return largest > 0.0
  return lower_fun(_proxy_reduce, multiple_results=False)(ctx, x, axes=axes)

# Register the lowering for `lax.reduce_or_p`.
lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule


def _broadcast_to_lowering_rule(
    ctx: LoweringRuleContext, x, shape: Sequence[int]
):
  raise RuntimeError(
      "`broadcast_to` is a Triton-specific primitive. Please consider using"
      " `jnp.broadcast_to` instead."
  )


# Register the (always-raising) rule for the state `broadcast_to_p` primitive.
lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule


def _broadcast_in_dim_lowering_rule(
    ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding
):
  """Lowers `broadcast_in_dim` to a shape_cast followed by a broadcast.

  Boolean inputs are routed through an integer proxy. For other dtypes, when
  `broadcast_dimensions` is non-empty the input is first shape-cast to a
  vector of the output rank with unit dims everywhere except the mapped ones;
  if that already equals the output shape, no broadcast is emitted.
  """
  del sharding
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out

  if jnp.issubdtype(aval_in.dtype, jnp.bool_):
    # Direct broadcasts for bools are not supported in Mosaic due to booleans
    # living in mask registers and broadcast operating on vregs. Broadcast as an
    # integer instead and cast back to a bool.
    # TODO(b/351019164): Implement this logic in Mosaic BroadcastOp instead.
    def _proxy_fun(val, *, shape, broadcast_dimensions):
      as_int = jnp.where(val, 1, 0)
      widened = jax.lax.broadcast_in_dim(as_int, shape, broadcast_dimensions)
      return widened == 1
    return lower_fun(_proxy_fun, multiple_results=False)(
        ctx, val, shape=shape, broadcast_dimensions=broadcast_dimensions)

  if broadcast_dimensions:
    expanded_dims = [1] * len(shape)
    for out_dim, in_size in zip(broadcast_dimensions, aval_in.shape):
      expanded_dims[out_dim] = in_size
    expanded_shape = tuple(expanded_dims)
    expanded_type = ir.VectorType.get(
        expanded_shape, _dtype_to_ir_type(aval_out.dtype)
    )
    val = vector.shape_cast(expanded_type, val)
    if expanded_shape == aval_out.shape:
      return val
  broadcast_type = ir.VectorType.get(
      aval_out.shape, _dtype_to_ir_type(aval_out.dtype)
  )
  return vector.broadcast(broadcast_type, val)


# Register the lowering for `lax.broadcast_in_dim_p`.
lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule


def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
  """Converts JAX dot dimension numbers to TPU dot dimension numbers.

  JAX dot dimension numbers are a tuple of tuples of int sequences of the form
  ((lhs_contracting_dims, rhs_contracting_dims),
   (lhs_batch_dims, rhs_batch_dims)).

  TPU dot dimension numbers are an MLIR attribute of the form
  #tpu.dot_dimension_numbers<...>, which is defined in the tpu dialect
  definition file (tpu.td).

  Returns:
    The parsed `#tpu.dot_dimension_numbers` attribute.
  """
  (lhs_contracting_dims, rhs_contracting_dims), batch_dims = dimension_numbers
  lhs_batch_dims, rhs_batch_dims = batch_dims

  lhs_rank, rhs_rank = len(lhs_shape), len(rhs_shape)
  lhs_non_contracting_dims = sorted(
      set(range(lhs_rank)) - set(lhs_contracting_dims) - set(lhs_batch_dims)
  )
  rhs_non_contracting_dims = sorted(
      set(range(rhs_rank)) - set(rhs_contracting_dims) - set(rhs_batch_dims)
  )

  # Each operand dimension maps to itself; the lookup only succeeds for
  # dimensions that exist on the operand.
  lhs_dim_map = {dim: dim for dim in range(lhs_rank)}
  rhs_dim_map = {dim: dim for dim in range(rhs_rank)}

  # The output order is a flat list of (operand, dim) pairs, where operand 0
  # is the lhs and 1 is the rhs. We assume the output dimensions are ordered
  # as batch dims, lhs non-contracting dims, rhs non-contracting dims, which
  # is the same assumption made in jax's dot_general.
  output_dim_order = []
  for operand, dims, dim_map in (
      (0, lhs_batch_dims, lhs_dim_map),
      (0, lhs_non_contracting_dims, lhs_dim_map),
      (1, rhs_non_contracting_dims, rhs_dim_map),
  ):
    for dim in dims:
      output_dim_order.extend((operand, dim_map[dim]))

  def format_dims(dims):
    return "[" + ", ".join(str(d) for d in dims) + "]"

  all_dims = (
      lhs_contracting_dims,
      rhs_contracting_dims,
      lhs_non_contracting_dims,
      rhs_non_contracting_dims,
      output_dim_order,
      lhs_batch_dims,
      rhs_batch_dims,
  )
  formatted = ",".join(format_dims(dims) for dims in all_dims)
  tpu_dim_numbers_str = f"#tpu.dot_dimension_numbers<{formatted}>"

  return ir.Attribute.parse(tpu_dim_numbers_str)


def _dot_general_lowering_rule(
    ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
):
  """Lowers `dot_general` for 2D operands.

  Two paths exist: a multiply + `vector.MultiDimReductionOp` path for the
  matrix-vector shaped case, and `tpu.matmul` for everything else.

  Raises:
    NotImplementedError: for unsupported output element types, non-2D
      operands, per-operand precisions that differ, or a precision other than
      DEFAULT / HIGHEST.
  """
  (lhs_dims, rhs_dims), _ = dimension_numbers
  (aval_out,) = ctx.avals_out
  out_type = aval_to_ir_type(aval_out)
  val_type = out_type.element_type
  # Build a zero attribute of the output element type; it seeds the
  # accumulator in both paths below.
  if any(
      cls.isinstance(val_type)
      for cls in [
          ir.BF16Type,
          ir.F32Type,
          ir.Float8E5M2Type,
          ir.Float8E4M3FNType,
      ]
  ):
    val = ir.FloatAttr.get(val_type, 0.0)
  elif ir.IntegerType.isinstance(val_type):
    val = ir.IntegerAttr.get(val_type, 0)
  else:
    raise NotImplementedError(ctx.avals_out[0].dtype)
  if any(len(a.shape) != 2 for a in ctx.avals_in):
    raise NotImplementedError(
        f"Only 2D tensors supported in dot; received: {ctx.avals_in}"
    )
  lhs_aval, rhs_aval = ctx.avals_in
  # This is really a matrix-vector product. It only looks like matrix-matrix.
  if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1:
    if ctx.avals_in[0].shape != ctx.avals_in[1].shape:
      bcast_shape = jnp.broadcast_shapes(
          ctx.avals_in[0].shape, ctx.avals_out[0].shape
      )
      bcast_shape = ir.VectorType.get(
          list(bcast_shape), _dtype_to_ir_type(ctx.avals_out[0].dtype)
      )
      # NOTE(review): `bcast_shape` is an ir.VectorType at this point, while
      # `.shape` on the avals is presumably a tuple, so both comparisons below
      # look like they are always true and both broadcasts are always emitted
      # -- confirm whether a shape-to-shape comparison was intended.
      if ctx.avals_in[0].shape != bcast_shape:
        x = vector.broadcast(bcast_shape, x)
      if ctx.avals_in[1].shape != bcast_shape:
        y = vector.broadcast(bcast_shape, y)
    # The accumulator is a 1D vector over the lhs's leading dimension.
    red_type = aval_to_ir_type(lhs_aval.update(shape=(lhs_aval.shape[0],)))
    acc = arith.ConstantOp(
        red_type, ir.DenseElementsAttr.get_splat(red_type, val)
    )
    # NOTE(review): the product uses arith.MulFOp (a float multiply) even
    # though an integer zero attribute can be built above -- confirm integer
    # operands never reach this path.
    red = vector.MultiDimReductionOp(
        ir.Attribute.parse("#vector.kind<add>"),
        arith.MulFOp(x, y),
        acc,
        [1]
    )
    return vector.shape_cast(out_type, red)

  tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims(
      dimension_numbers, lhs_aval.shape, rhs_aval.shape
  )

  # `precision`, when given, is a pair (one entry per operand); both entries
  # must agree.
  if precision is not None:
    if precision[0] != precision[1]:
      raise NotImplementedError("Per-operand dot precision unsupported")
    precision = precision[0]
  if precision is None or precision == lax.Precision.DEFAULT:
    precision_attr = None  # That's the default in Mosaic.
  elif precision == lax.Precision.HIGHEST:
    precision_attr = ir.Attribute.parse(
        "#tpu.contract_precision<fp32>"
    )
  else:
    raise NotImplementedError(f"Unsupported dot precision: {precision}")
  # Zero-filled accumulator tile passed to the matmul.
  out_tile = arith.ConstantOp(
      out_type, ir.DenseElementsAttr.get_splat(out_type, val)
  )
  return tpu.matmul(
      out_type,
      x,
      y,
      out_tile,
      dimension_numbers=tpu_dot_dims,
      precision=precision_attr,
  )


# Register the lowering for `lax.dot_general_p`.
lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule

def _convert_helper(x, *, to_dtype):
  # Helper function for dtype conversion
  from_dtype = x.dtype
  if jnp.issubdtype(from_dtype, jnp.dtype("bool")):
    x = x.astype(jnp.int32)
    return _convert_helper(x, to_dtype=to_dtype)
  if jnp.issubdtype(from_dtype, jnp.signedinteger):
    if from_dtype.itemsize < 4:
      x = x.astype(jnp.int32)
    if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4:
      x = x.astype(jnp.float32)
    return x.astype(to_dtype)
  if jnp.issubdtype(from_dtype, jnp.unsignedinteger):
    if from_dtype.itemsize < 4:
      x = x.astype(jnp.uint32)
    if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4:
      x = x.astype(jnp.float32)
    return x.astype(to_dtype)
  if jnp.issubdtype(from_dtype, jnp.floating):
    if jnp.issubdtype(to_dtype, np.dtype("bool")):
      # Cast to float32 rather than int32 because 0 < |x| < 1 rounds to 0,
      # leading to false in bool. However, convert_element_type(x, bool)
      # returns true. It's handled correctly when x is float32.
      x = x.astype(jnp.float32)
    elif jnp.issubdtype(to_dtype, jnp.signedinteger):
      if from_dtype.itemsize < 4:
        x = x.astype(jnp.float32)
      if to_dtype.itemsize < 4:
        # Need to clip values to match XLA
        minval, maxval = jnp.iinfo(to_dtype).min, jnp.iinfo(to_dtype).max
        x = jnp.clip(x, minval, maxval)
        return x.astype(jnp.int32).astype(to_dtype)
    return x.astype(to_dtype)
  raise NotImplementedError(f"Unsupported cast: {from_dtype} -> {to_dtype}")

def _convert_element_type_lowering_rule(
    ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding
):
  """Lowers `convert_element_type` to `arith` cast ops.

  Casts with a direct lowering are emitted as a single op. Everything else is
  delegated to `_convert_helper`, which splits the cast into several steps.

  Args:
    ctx: Lowering context; the input and output avals are read from it.
    x: The value to convert.
    new_dtype: The target dtype.
    weak_type: Unused.
    sharding: Unused.

  Returns:
    The converted value.

  Raises:
    NotImplementedError: If `new_dtype` is 8 bytes wide.
  """
  del weak_type
  del sharding
  out_aval = ctx.avals_out[0]
  in_aval = ctx.avals_in[0]
  old_dtype = in_aval.dtype
  out_type = aval_to_ir_type(out_aval)

  if old_dtype == new_dtype:
    return x

  if new_dtype.itemsize == 8:
    raise NotImplementedError("64-bit types are not supported")

  # Unsigned sources need the zero-extending / unsigned variants of the cast
  # ops: the signed variants reinterpret values with the top bit set as
  # negative numbers.
  from_unsigned = jnp.issubdtype(old_dtype, jnp.unsignedinteger)

  if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype(
      new_dtype, jnp.floating
  ):
    if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
      return arith.extf(out_type, x)
    elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
      return arith.truncf(out_type, x)
  elif jnp.issubdtype(old_dtype, jnp.integer) and jnp.issubdtype(
      new_dtype, jnp.integer
  ):
    if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4:
      # The extension kind is determined by the signedness of the source.
      if from_unsigned:
        return arith.extui(out_type, x)
      return arith.extsi(out_type, x)
    elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
      return arith.trunci(out_type, x)
    elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits:
      # This case triggers when casting signed to unsigned or vice versa.
      return x
  elif jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype(
      new_dtype, jnp.signedinteger
  ) and old_dtype.itemsize == new_dtype.itemsize == 4:
    return arith.fptosi(out_type, x)
  elif jnp.issubdtype(old_dtype, jnp.integer) and jnp.issubdtype(
      new_dtype, jnp.floating
  ) and old_dtype.itemsize == new_dtype.itemsize == 4:
    if from_unsigned:
      return arith.uitofp(out_type, x)
    return arith.sitofp(out_type, x)
  elif (
      old_dtype == jnp.bool_
      and jnp.issubdtype(new_dtype, jnp.integer)
      and new_dtype.itemsize == 4
  ):
    return arith.extui(out_type, x)
  elif (
      (
          (is_float := jnp.issubdtype(old_dtype, jnp.floating))
          or jnp.issubdtype(old_dtype, jnp.integer)
      )
      and new_dtype == jnp.bool_
      and old_dtype.itemsize == 4
  ):
    # Lower float32 or (u)int32 -> bool to cmp neq %in, 0
    # NOTE(review): the float predicate looked up here is the ordered `ONE`,
    # under which a NaN input compares false; confirm that this is the
    # intended NaN -> bool result.
    const_type = _dtype_to_ir_type(old_dtype)
    if is_float:
      pred = _cmpf_lowering_types[lax.ne_p]
      const_zero = ir.FloatAttr.get(const_type, 0)
      op = arith.CmpFOp
    else:
      pred = _cmpsi_lowering_types[lax.ne_p]
      const_zero = ir.IntegerAttr.get(const_type, 0)
      op = arith.CmpIOp
    predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred)
    if in_aval.shape:
      # Non-scalar input: compare against a splat of zeros of the input type.
      in_type = aval_to_ir_type(in_aval, is_kernel_boundary=False)
      vector_zeros = arith.ConstantOp(
          in_type,
          ir.DenseElementsAttr.get_splat(in_type, const_zero),
      )
      return op(predicate, x, vector_zeros).result
    return op(predicate, x, arith.ConstantOp(const_type, const_zero)).result
  # No direct lowering matched: split the cast into supported steps.
  return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
                   multiple_results=False)(ctx, x)


lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule


def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions,
                           sharding):
  """Lowers `reshape` to a vector broadcast (scalar input) or a shape cast.

  Raises:
    NotImplementedError: If `dimensions` is given or any entry of `new_sizes`
      is `None`.
  """
  if dimensions is not None:
    raise NotImplementedError
  for size in new_sizes:
    if size is None:
      raise NotImplementedError
  result_type = aval_to_ir_type(ctx.avals_out[0])
  if ctx.avals_in[0].shape:
    return vector.shape_cast(result_type, x)
  # A rank-0 input cannot be shape-cast; broadcast it to the result instead.
  return vector.broadcast(result_type, x)


lowering_rules[lax.reshape_p] = _reshape_lowering_rule


def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
  """Lowers `squeeze` to a shape cast, or to an extract for scalar results.

  Raises:
    ValueError: If the result is a scalar whose element type is not 4 bytes
      wide.
  """
  del dimensions  # Unused.
  (aval_in,) = ctx.avals_in
  (aval_out,) = ctx.avals_out
  if aval_out.shape:
    return vector.shape_cast(aval_to_ir_type(ctx.avals_out[0]), x)
  if aval_out.dtype.itemsize != 4:
    raise ValueError(
        "Only arrays with 32-bit element types can be converted to scalars,"
        f" but got: {aval_out.dtype}. Try casting the input before squeezing"
        " the scalar."
    )
  # Scalar result: pull out the element at index 0 along every input axis.
  first_element = [0] * len(aval_in.shape)
  return vector.extract(x, [], first_element)


lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule


def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension):
  """Lowers `concatenate` of `xs` along `dimension` to `tpu.concatenate`."""
  return tpu.concatenate(
      aval_to_ir_type(ctx.avals_out[0]), xs, dimension=dimension
  )


lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule


def _split_lowering_rule(
    ctx: LoweringRuleContext, x, *, sizes, axis
):
  """Lowers `split` to one `vector.extract_strided_slice` per output.

  Each slice spans the full input along every axis except `axis`, where it
  covers consecutive ranges of the lengths given in `sizes`.
  """
  (x_aval,) = ctx.avals_in
  extent = np.array(x_aval.shape, dtype=np.int64)
  offsets = np.zeros_like(extent)
  unit_strides = np.ones_like(extent)
  pieces = []
  offset = 0
  for piece_aval, piece_size in zip(ctx.avals_out, sizes):
    extent[axis] = piece_size
    offsets[axis] = offset
    piece = vector.extract_strided_slice(
        aval_to_ir_type(piece_aval), x, offsets, extent, unit_strides
    )
    pieces.append(piece)
    offset += piece_size
  return pieces

lowering_rules[lax.split_p] = _split_lowering_rule


def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension,
                        sharding):
  """Lowers `iota` to `tpu.iota` along `dimension`."""
  # The result type is derived from the output aval, not from these params.
  del dtype, shape, sharding
  return tpu.iota(aval_to_ir_type(ctx.avals_out[0]), dimension=dimension)


lowering_rules[lax.iota_p] = _iota_lowering_rule


def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation):
  """Lowers a transpose with permutation `(1, 0)` to `vector.transpose`.

  Raises:
    NotImplementedError: For any other permutation.
  """
  if permutation == (1, 0):
    result_type = aval_to_ir_type(ctx.avals_out[0])
    return vector.transpose(result_type, x, permutation)
  raise NotImplementedError


lowering_rules[lax.transpose_p] = _transpose_lowering_rule


def _bcast(x, y, x_aval, y_aval, out_aval):
  """Materializes constant operands and broadcasts both to the output shape.

  Args:
    x: First operand; either an IR value or a NumPy/Python constant.
    y: Second operand; either an IR value or a NumPy/Python constant.
    x_aval: Abstract value of `x`.
    y_aval: Abstract value of `y`.
    out_aval: Abstract value of the result.

  Returns:
    The pair `(x, y)` after constant materialization and broadcasting.
  """
  constant_types = (np.ndarray, np.number, int, float)

  # A weakly typed operand adopts the dtype of the other one; `y` wins the
  # check when both are weak.
  if y_aval.weak_type:
    x_dtype = y_dtype = x_aval.dtype
  elif x_aval.weak_type:
    x_dtype = y_dtype = y_aval.dtype
  else:
    x_dtype, y_dtype = x_aval.dtype, y_aval.dtype

  # Constants become IR constants. When the other operand is index-typed the
  # constant reuses that type instead of the one derived from its dtype.
  if isinstance(x, constant_types):
    y_is_index = getattr(y, "type", None) == ir.IndexType.get()
    x = ir_constant(x, y.type if y_is_index else _dtype_to_ir_type(x_dtype))
  if isinstance(y, constant_types):
    x_is_index = getattr(x, "type", None) == ir.IndexType.get()
    y = ir_constant(y, x.type if x_is_index else _dtype_to_ir_type(y_dtype))

  target_shape = list(out_aval.shape)
  if x_aval.shape != out_aval.shape:
    x = vector.broadcast(
        ir.VectorType.get(target_shape, _dtype_to_ir_type(x_dtype)), x
    )
  if y_aval.shape != out_aval.shape:
    y = vector.broadcast(
        ir.VectorType.get(target_shape, _dtype_to_ir_type(y_dtype)), y
    )
  return x, y


def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers addition to `arith.addi` or `arith.addf` based on the output dtype.

  Raises:
    NotImplementedError: If the output dtype is neither integer nor floating.
  """
  lhs, rhs = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  out_dtype = aval_out.dtype
  for kind, emit in ((jnp.integer, arith.addi), (jnp.floating, arith.addf)):
    if jnp.issubdtype(out_dtype, kind):
      return emit(lhs, rhs)
  raise NotImplementedError(out_dtype)


lowering_rules[lax.add_p] = _add_lowering_rule
skip_mlir_conversions.add(lax.add_p)
lowering_rules[ad_util.add_any_p] = _add_lowering_rule
skip_mlir_conversions.add(ad_util.add_any_p)


def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers `max` to the signed, unsigned or float `arith` maximum op.

  Raises:
    NotImplementedError: If the output dtype is neither integer nor floating.
  """
  lhs, rhs = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  out_dtype = aval_out.dtype
  candidates = (
      (jnp.signedinteger, arith.maxsi),
      (jnp.unsignedinteger, arith.maxui),
      (jnp.floating, arith.maximumf),
  )
  for kind, emit in candidates:
    if jnp.issubdtype(out_dtype, kind):
      return emit(lhs, rhs)
  raise NotImplementedError(out_dtype)


lowering_rules[lax.max_p] = _max_lowering_rule
skip_mlir_conversions.add(lax.max_p)


def _min_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers `min` to the signed, unsigned or float `arith` minimum op.

  Raises:
    NotImplementedError: If the output dtype is neither integer nor floating.
  """
  lhs, rhs = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  out_dtype = aval_out.dtype
  candidates = (
      (jnp.signedinteger, arith.minsi),
      (jnp.unsignedinteger, arith.minui),
      (jnp.floating, arith.minimumf),
  )
  for kind, emit in candidates:
    if jnp.issubdtype(out_dtype, kind):
      return emit(lhs, rhs)
  raise NotImplementedError(out_dtype)


lowering_rules[lax.min_p] = _min_lowering_rule
skip_mlir_conversions.add(lax.min_p)


def _sub_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers subtraction to `arith.subi` or `arith.subf` by output dtype.

  Raises:
    NotImplementedError: If the output dtype is neither integer nor floating.
  """
  lhs, rhs = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  out_dtype = aval_out.dtype
  for kind, emit in ((jnp.integer, arith.subi), (jnp.floating, arith.subf)):
    if jnp.issubdtype(out_dtype, kind):
      return emit(lhs, rhs)
  raise NotImplementedError(out_dtype)


lowering_rules[lax.sub_p] = _sub_lowering_rule
skip_mlir_conversions.add(lax.sub_p)


def _mul_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers multiplication to `arith.muli` or `arith.mulf` by output dtype.

  Raises:
    NotImplementedError: If the output dtype is neither integer nor floating.
  """
  lhs, rhs = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  out_dtype = aval_out.dtype
  for kind, emit in ((jnp.integer, arith.muli), (jnp.floating, arith.mulf)):
    if jnp.issubdtype(out_dtype, kind):
      return emit(lhs, rhs)
  raise NotImplementedError(out_dtype)


lowering_rules[lax.mul_p] = _mul_lowering_rule
skip_mlir_conversions.add(lax.mul_p)


def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers division to the signed, unsigned or float `arith` division op.

  Raises:
    NotImplementedError: If the output dtype is neither integer nor floating.
  """
  lhs, rhs = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  out_dtype = aval_out.dtype
  candidates = (
      (jnp.signedinteger, arith.divsi),
      (jnp.unsignedinteger, arith.divui),
      (jnp.floating, arith.divf),
  )
  for kind, emit in candidates:
    if jnp.issubdtype(out_dtype, kind):
      return emit(lhs, rhs)
  raise NotImplementedError(out_dtype)


lowering_rules[lax.div_p] = _div_lowering_rule
skip_mlir_conversions.add(lax.div_p)


def _rem_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers remainder to the signed, unsigned or float `arith` remainder op.

  Raises:
    NotImplementedError: If the output dtype is neither integer nor floating.
  """
  lhs, rhs = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  (aval_out,) = ctx.avals_out
  out_dtype = aval_out.dtype
  candidates = (
      (jnp.signedinteger, arith.remsi),
      (jnp.unsignedinteger, arith.remui),
      (jnp.floating, arith.remf),
  )
  for kind, emit in candidates:
    if jnp.issubdtype(out_dtype, kind):
      return emit(lhs, rhs)
  raise NotImplementedError(out_dtype)


lowering_rules[lax.rem_p] = _rem_lowering_rule
skip_mlir_conversions.add(lax.rem_p)


def _abs_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `abs` to `math.absi` or `math.absf` based on the output dtype.

  Raises:
    NotImplementedError: If the output dtype is neither integer nor floating.
  """
  (aval_out,) = ctx.avals_out
  out_dtype = aval_out.dtype
  for kind, emit in ((jnp.integer, math.absi), (jnp.floating, math.absf)):
    if jnp.issubdtype(out_dtype, kind):
      return emit(x)
  raise NotImplementedError(out_dtype)


lowering_rules[lax.abs_p] = _abs_lowering_rule


def _neg_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers negation as `0 - x` by reusing the subtraction rule."""
  (x_aval,) = ctx.avals_in
  # Present the subtraction rule with a scalar zero as an extra first operand.
  zero_aval = jax_core.ShapedArray((), x_aval.dtype)
  sub_ctx = ctx.replace(
      avals_in=(zero_aval, x_aval),
      block_shapes=((), *ctx.block_shapes)
  )
  zero = np.array(0, dtype=x_aval.dtype)
  return _sub_lowering_rule(sub_ctx, zero, x)


lowering_rules[lax.neg_p] = _neg_lowering_rule
skip_mlir_conversions.add(lax.neg_p)


def _sign_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `sign` by lowering `pallas_utils.sign_lowering_helper`."""
  lowered = lower_fun(
      pallas_utils.sign_lowering_helper, multiple_results=False,
  )
  return lowered(ctx, x)


lowering_rules[lax.sign_p] = _sign_lowering_rule


def _nextafter_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers `nextafter` by lowering `pallas_utils.nextafter_lowering_helper`."""
  lowered = lower_fun(
      pallas_utils.nextafter_lowering_helper, multiple_results=False,
  )
  return lowered(ctx, x, y)


lowering_rules[lax.nextafter_p] = _nextafter_lowering_rule


def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `rsqrt` to `math.rsqrt`."""
  del ctx  # Unused.
  return math.rsqrt(x)


lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule


def _sqrt_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `sqrt` to `math.sqrt`."""
  del ctx  # Unused.
  return math.sqrt(x)


lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule


def _square_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `square` to `x * x`, using the integer or float multiply op."""
  operand_is_int = jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer)
  multiply = arith.muli if operand_is_int else arith.mulf
  return multiply(x, x)


lowering_rules[lax.square_p] = _square_lowering_rule


def _exp_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `exp` to `math.exp`."""
  del ctx  # Unused.
  return math.exp(x)


lowering_rules[lax.exp_p] = _exp_lowering_rule


def _pow_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers `pow` to `math.powf`, or to `math.exp2` for a constant base of 2.

  jax accepts a float base (`x`) and an integer or float exponent (`y`); an
  integer exponent is cast to float before the op is emitted.
  """
  out_type = aval_to_ir_type(ctx.avals_out[0])
  exponent_is_int = jnp.issubdtype(ctx.avals_in[1].dtype, jnp.integer)
  if exponent_is_int:
    y = arith.sitofp(out_type, y)
  # Only a base that has not been turned into an IR value can be compared
  # against 2 at lowering time.
  base_is_constant = not isinstance(x, ir.Value)
  if base_is_constant and x == 2.:
    return math.exp2(y)
  base, exponent = _bcast(
      x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]
  )
  return math.powf(base, exponent)


lowering_rules[lax.pow_p] = _pow_lowering_rule
skip_mlir_conversions.add(lax.pow_p)


def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, *, y):
  """Lowers `integer_pow` by lowering `lax_internal._integer_pow`."""
  lowered = lower_fun(lax_internal._integer_pow, multiple_results=False)
  return lowered(ctx, x, y=y)


lowering_rules[lax.integer_pow_p] = _integer_pow_lowering_rule


def _exp2_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `exp2(x)` as `exp(ln(2) * x)`."""
  # exp2 in JAX lowers to exp(ln2 * x), not to pow2. We match that behavior
  # here.
  # `np.log(2)` is a NumPy float64 scalar, which JAX does not treat as weakly
  # typed, so multiplying by it can promote a narrower `x` (e.g. bfloat16)
  # away from the output dtype. A Python float is weakly typed and keeps the
  # dtype of `x`.
  ln2 = float(np.log(2))
  return lower_fun(lambda x: jnp.exp(ln2 * x), multiple_results=False)(
      ctx, x)


lowering_rules[lax.exp2_p] = _exp2_lowering_rule
skip_mlir_conversions.add(lax.exp2_p)


def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `logistic(x)` as `1 / (1 + exp(-x))`."""
  neg_x = arith.negf(x)
  exp_neg_x = math.exp(neg_x)
  aval_out = ctx.avals_out[0]
  out_type = aval_to_ir_type(aval_out)
  if aval_out.shape == ():
    one = ir_constant(1.0, mlir_type=out_type)
  else:
    # Build the scalar with the output element type explicitly, as the scalar
    # branch does, instead of relying on the default type `ir_constant` picks
    # for a Python float; the broadcast needs matching element types.
    scalar_one = ir_constant(
        1.0, mlir_type=_dtype_to_ir_type(aval_out.dtype)
    )
    one = vector.broadcast(out_type, scalar_one)
  denom = arith.addf(one, exp_neg_x)
  return arith.divf(one, denom)


lowering_rules[lax.logistic_p] = _logistic_lowering_rule


def _sin_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `sin` to `math.sin`."""
  del ctx  # Unused.
  return math.sin(x)


lowering_rules[lax.sin_p] = _sin_lowering_rule


def _cos_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `cos` to `math.cos`."""
  del ctx  # Unused.
  return math.cos(x)


lowering_rules[lax.cos_p] = _cos_lowering_rule


def _tan_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `tan` to `math.tan`."""
  del ctx  # Unused.
  return math.tan(x)


lowering_rules[lax.tan_p] = _tan_lowering_rule


def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `tanh` to `math.tanh`."""
  del ctx  # Unused.
  return math.tanh(x)


lowering_rules[lax.tanh_p] = _tanh_lowering_rule


def _log_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `log` to `math.log`."""
  del ctx  # Unused.
  return math.log(x)


lowering_rules[lax.log_p] = _log_lowering_rule


def _log1p_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `log1p` to `math.log1p`."""
  del ctx  # Unused.
  return math.log1p(x)


lowering_rules[lax.log1p_p] = _log1p_lowering_rule


def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method):
  """Lowers `round` to `math.round` (method 0) or `math.roundeven` (method 1).

  Raises:
    NotImplementedError: For any other `rounding_method`.
  """
  del ctx  # Unused.
  if rounding_method == 0:
    return math.round(x)
  if rounding_method == 1:
    return math.roundeven(x)
  raise NotImplementedError(f"Unsupported rounding method: {rounding_method}")


lowering_rules[lax.round_p] = _round_lowering_rule


def _ceil_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `ceil` to `math.ceil`."""
  del ctx  # Unused.
  return math.ceil(x)


lowering_rules[lax.ceil_p] = _ceil_lowering_rule


def _floor_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `floor` to `math.floor`."""
  del ctx  # Unused.
  return math.floor(x)


lowering_rules[lax.floor_p] = _floor_lowering_rule


def _clz_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `clz` (count leading zeros) to `math.ctlz`."""
  del ctx  # Unused.
  return math.ctlz(x)

lowering_rules[lax.clz_p] = _clz_lowering_rule


def _population_count_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `population_count` to `math.ctpop`.

  Raises:
    ValueError: If the output is a scalar.
  """
  if ctx.avals_out[0].shape == ():
    raise ValueError("Population count is not supported on scalars")
  return math.ctpop(x)

lowering_rules[lax.population_count_p] = _population_count_lowering_rule


# Mapping for signed integer comparisons.
# Keys are `lax` comparison primitives; values are the `arith.cmpi` predicates
# that use signed ordering (`eq`/`ne` do not depend on signedness).
_cmpsi_lowering_types = {
    lax.eq_p: arith.CmpIPredicate.eq,
    lax.ne_p: arith.CmpIPredicate.ne,
    lax.lt_p: arith.CmpIPredicate.slt,
    lax.le_p: arith.CmpIPredicate.sle,
    lax.gt_p: arith.CmpIPredicate.sgt,
    lax.ge_p: arith.CmpIPredicate.sge,
}

# Mapping for unsigned integer comparisons.
# Same keys as above, but the ordering predicates are the unsigned variants.
_cmpui_lowering_types = {
    lax.eq_p: arith.CmpIPredicate.eq,
    lax.ne_p: arith.CmpIPredicate.ne,
    lax.lt_p: arith.CmpIPredicate.ult,
    lax.le_p: arith.CmpIPredicate.ule,
    lax.gt_p: arith.CmpIPredicate.ugt,
    lax.ge_p: arith.CmpIPredicate.uge,
}

# Mapping for floating point comparisons.
# All entries are the "ordered" `arith.cmpf` predicates.
# NOTE(review): ordered predicates evaluate to false when either operand is
# NaN, so with `ONE` a NaN never compares "not equal"; confirm that this is
# the intended behavior for `ne`.
_cmpf_lowering_types = {
    lax.eq_p: arith.CmpFPredicate.OEQ,
    lax.ne_p: arith.CmpFPredicate.ONE,
    lax.lt_p: arith.CmpFPredicate.OLT,
    lax.le_p: arith.CmpFPredicate.OLE,
    lax.gt_p: arith.CmpFPredicate.OGT,
    lax.ge_p: arith.CmpFPredicate.OGE,
}


# The relationship between comparison operations on booleans and boolean
# algebra is as follows:
# eq(x, y) = !(x ^ y)
# ne(x, y) = x ^ y
# lt(x, y) = !x && y
# le(x, y) = !x || y
# gt(x, y) = x && !y
# ge(x, y) = x || !y
def _cmp_boolean_lowering_helper(primitive, x: Array, y: Array):
  """A helper function for lowering comparison operations for boolean inputs.

  Args:
    primitive: A JAX primitive representing a comparison operation, which is
      one of the following: `lax.eq_p` (equals), `lax.ne_p` (not equals),
      `lax.lt_p` (less than), `lax.le_p` (less than or equal to),
      `lax.gt_p` (greater than), or `lax.ge_p` (greater than or equal to).
    x: A boolean array representing the first operand in the comparison.
    y: A boolean array representing the second operand in the comparison.

  Returns:
    A boolean array that is the result of applying the comparison operation
    between `x` and `y` based on the given primitive.

  Raises:
    ValueError: If an unsupported comparison primitive is provided.
  """
  if primitive == lax.eq_p:
    return jnp.logical_not(jnp.logical_xor(x, y))
  elif primitive == lax.ne_p:
    return jnp.logical_xor(x, y)
  elif primitive == lax.lt_p:
    return jnp.logical_and(jnp.logical_not(x), y)
  elif primitive == lax.le_p:
    return jnp.logical_or(jnp.logical_not(x), y)
  elif primitive == lax.gt_p:
    return jnp.logical_and(x, jnp.logical_not(y))
  elif primitive == lax.ge_p:
    return jnp.logical_or(x, jnp.logical_not(y))
  else:
    raise ValueError(f"Unsupported comparison primitive: {primitive}")


def _cmp_lowering_rule(primitive, ctx: LoweringRuleContext, x, y):
  """Lowers a comparison primitive to `arith.cmpi` / `arith.cmpf`.

  Boolean operands are handled by lowering `_cmp_boolean_lowering_helper`.

  Args:
    primitive: The `lax` comparison primitive being lowered.
    ctx: Lowering context; operand avals are read from it.
    x: Left-hand operand.
    y: Right-hand operand.

  Returns:
    The result of the comparison.

  Raises:
    ValueError: If the two operands have different dtypes.
    NotImplementedError: If the dtype is not bool, integer or floating.
  """
  x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
  x_aval, y_aval = ctx.avals_in
  if x_aval.dtype != y_aval.dtype:
    raise ValueError(
        f"Mixed dtype operands in cmp: {x_aval.dtype}, {y_aval.dtype}"
    )
  dtype = x_aval.dtype

  if jnp.issubdtype(dtype, jnp.bool_):
    boolean_cmp = functools.partial(_cmp_boolean_lowering_helper, primitive)
    return lower_fun(boolean_cmp, multiple_results=False)(ctx, x, y)

  # Pick the predicate table and the op to emit from the operand dtype.
  if jnp.issubdtype(dtype, jnp.integer):
    if jnp.issubdtype(dtype, jnp.unsignedinteger):
      predicates = _cmpui_lowering_types
    else:
      predicates = _cmpsi_lowering_types
    emit = arith.cmpi
  elif jnp.issubdtype(dtype, jnp.floating):
    predicates = _cmpf_lowering_types
    emit = arith.cmpf
  else:
    raise NotImplementedError(f"Unsupported dtype in cmp: {dtype}")

  predicate = ir.IntegerAttr.get(
      ir.IntegerType.get_signless(64), predicates[primitive]
  )
  return emit(predicate, x, y)


lowering_rules[lax.eq_p] = functools.partial(_cmp_lowering_rule, lax.eq_p)
lowering_rules[lax.ne_p] = functools.partial(_cmp_lowering_rule, lax.ne_p)
lowering_rules[lax.lt_p] = functools.partial(_cmp_lowering_rule, lax.lt_p)
lowering_rules[lax.le_p] = functools.partial(_cmp_lowering_rule, lax.le_p)
lowering_rules[lax.gt_p] = functools.partial(_cmp_lowering_rule, lax.gt_p)
lowering_rules[lax.ge_p] = functools.partial(_cmp_lowering_rule, lax.ge_p)


def _and_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers `and` to `arith.andi` after broadcasting the operands."""
  lhs, rhs = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return arith.andi(lhs, rhs)


lowering_rules[lax.and_p] = _and_lowering_rule
skip_mlir_conversions.add(lax.and_p)


def _is_finite_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `is_finite` as the negation of `tpu.weird`."""
  (out_aval,) = ctx.avals_out
  weird = tpu.weird(aval_to_ir_type(out_aval), x)
  return _not_lowering_rule(ctx, weird)


lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule


def _or_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers `or` to `arith.ori` after broadcasting the operands."""
  lhs, rhs = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  return arith.ori(lhs, rhs)


lowering_rules[lax.or_p] = _or_lowering_rule
skip_mlir_conversions.add(lax.or_p)


def _not_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `not` as `xor x, -1`.

  The primitive not_p is lowered to
  https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not
  which is arithmetic for integers and logical for booleans. XOR-ing with -1
  covers both cases.
  """
  out_aval = ctx.avals_out[0]
  element_type = _dtype_to_ir_type(out_aval.dtype)
  if out_aval.shape:
    # Non-scalar output: splat -1 across the whole output type.
    full_type = aval_to_ir_type(out_aval)
    splat = ir.DenseElementsAttr.get_splat(
        full_type, ir.IntegerAttr.get(element_type, -1)
    )
    all_ones = arith.ConstantOp(full_type, splat)
  else:
    # Scalar output: a plain scalar constant is enough.
    all_ones = ir_constant(-1, element_type)
  return arith.xori(x, all_ones)


lowering_rules[lax.not_p] = _not_lowering_rule

def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args):
  """Lowers `select_n` with at most two cases to `arith.select`.

  A predicate that is not boolean is first converted with `pred != 0`.

  Raises:
    NotImplementedError: If more than two cases are given.
  """
  if len(args) > 1:
    raise NotImplementedError("select_n only supported with <= 2 arguments")
  pred_aval, _ = ctx.avals_in[:2]
  pred_is_bool = pred_aval.dtype == np.dtype(np.bool_)
  if not pred_is_bool:
    def _is_nonzero(value):
      return value != 0

    convert_ctx = LoweringRuleContext(
        ctx.lowering_context,
        avals_in=[pred_aval],
        avals_out=[pred_aval.update(dtype=np.bool_)],
        block_shapes=[None],
    )
    pred = lower_fun(_is_nonzero, multiple_results=False)(convert_ctx, pred)
  if not args:
    # Only one case: nothing to select between.
    return x
  # Exactly two cases remain. `select_n` picks `x` for a false predicate,
  # while `arith.select` takes the true value first.
  (y,) = args
  return arith.select(pred, y, x)


lowering_rules[lax.select_n_p] = _select_n_lowering_rule


def _clamp(min, operand, max):
  res = jnp.maximum(operand, min)
  return jnp.minimum(res, max)


def _clamp_lowering_rule(ctx: LoweringRuleContext, min, operand, max):
  """Lowers `clamp` by lowering `_clamp`, i.e. min(max(operand, min), max)."""
  lowered = lower_fun(_clamp, multiple_results=False)
  return lowered(ctx, min, operand, max)


lowering_rules[lax.clamp_p] = _clamp_lowering_rule


def _for_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    jaxpr,
    nsteps,
    reverse,
    unroll,
    which_linear,
):
  """Lowers `for_p` by fully unrolling the body `nsteps` times.

  State is discharged for every operand that is not an `AbstractRef`; those
  operands are threaded through the iterations as values, while ref operands
  are passed through unchanged.
  """
  del unroll, which_linear  # Unused.
  should_discharge = [
      not isinstance(aval, state.AbstractRef) for aval in ctx.avals_in
  ]
  # The leading `False` corresponds to the loop index argument of the body.
  jaxpr, () = state_discharge.discharge_state(
      jaxpr, (), should_discharge=[False, *should_discharge]
  )
  step_order = reversed(range(nsteps)) if reverse else range(nsteps)
  for step in step_order:
    index = ir_constant(step)
    lowering_context = ctx.lowering_context.replace(
        block_shapes=[(), *ctx.block_shapes],
    )
    updated = iter(jaxpr_subcomp(lowering_context, jaxpr, index, *args))
    # Replace discharged operands with the values produced by this step.
    args = [
        next(updated) if discharged else arg
        for arg, discharged in zip(args, should_discharge)
    ]
  return args


lowering_rules[for_loop.for_p] = _for_lowering_rule


def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
                             jaxpr: jax_core.Jaxpr, start: int | ir.Value,
                             num_steps: int | ir.Value, consts, *args,
                             has_loop_index: bool,
                             unroll: int):
  """Lowers a loop body jaxpr to a fully unrolled loop or an `scf.for`.

  Args:
    ctx: Lowering context of the enclosing rule.
    jaxpr: The loop body.
    start: First value of the loop index.
    num_steps: Number of iterations.
    consts: Values passed unchanged to every iteration.
    *args: Initial values of the loop carry.
    has_loop_index: Whether the body takes the loop index as an argument
      (placed between the consts and the carry).
    unroll: Unroll factor; only `1` and full unrolling are supported.

  Returns:
    The final values of the loop carry.

  Raises:
    NotImplementedError: If `unroll` is neither `1` nor a full unroll of a
      loop with static bounds.
  """
  def _run_body(index, carry):
    if has_loop_index:
      shapes = ctx.block_shapes
      operands = [*consts, index, *carry]
    else:
      # The body does not take the index: drop its block shape entry too.
      num_consts = len(consts)
      shapes = (
          ctx.block_shapes[:num_consts] + ctx.block_shapes[num_consts + 1:]
      )
      operands = [*consts, *carry]
    body_ctx = ctx.lowering_context.replace(block_shapes=shapes)
    return jaxpr_subcomp(body_ctx, jaxpr, *operands)

  static_bounds = (
      not isinstance(start, ir.Value) and not isinstance(num_steps, ir.Value)
  )
  if static_bounds and num_steps == unroll:
    # No need for an scf.For. We can just unroll completely
    for step in range(start, start + num_steps):
      index = ir_constant(
          step, mlir_type=_dtype_to_ir_type(jnp.dtype("int32"))
      )
      args = _run_body(index, args)
    return args
  if unroll != 1:
    raise NotImplementedError(
        f"Only unroll={num_steps=} and unroll=1 supported. Got {unroll=}.")
  lower_bound = _ensure_mlir_value(start, pallas_core.index_map_grid_aval)
  upper_bound = arith.addi(
      lower_bound,
      _ensure_mlir_value(num_steps, pallas_core.index_map_grid_aval),
  )
  unit_step = ir_constant(1, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
  loop = scf.ForOp(lower_bound, upper_bound, unit_step, args)
  with ir.InsertionPoint(loop.body):
    carry_out = _run_body(loop.induction_variable, loop.inner_iter_args)
    scf.YieldOp(carry_out)
  return loop.results


def _scan_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    jaxpr: jax_core.ClosedJaxpr,
    linear: tuple[bool, ...],
    length: int,
    reverse: bool,
    unroll: bool | int,
    num_consts: int,
    num_carry: int,
    _split_transpose: bool,
):
  """Lowers fori_loop-like scans via `_lower_jaxpr_to_for_loop`.

  Raises:
    NotImplementedError: If the scan has extensive (scanned-over) operands,
      runs in reverse, or its jaxpr carries constants.
  """
  del _split_transpose, linear
  # Can only handle fori_loop-like scans
  if len(args) - num_consts - num_carry:
    raise NotImplementedError
  if reverse:
    raise NotImplementedError
  if jaxpr.consts:
    raise NotImplementedError

  body_jaxpr, has_loop_index = pallas_utils.pattern_match_scan_to_fori_loop(
      jaxpr.jaxpr, num_consts, num_carry
  )
  consts, carry = split_list(args, [num_consts])
  consts_avals, carry_avals = split_list(ctx.avals_in, [num_consts])
  if has_loop_index:
    # The first carried value is the loop index; peel it off.
    loop_index_start, *carry = carry
    carry_avals = carry_avals[1:]
  else:
    loop_index_start = 0
  consts = map(_ensure_mlir_value, consts, consts_avals)
  carry = map(_ensure_mlir_value, carry, carry_avals)
  out = _lower_jaxpr_to_for_loop(
      ctx, body_jaxpr, loop_index_start, length,
      consts, *carry, has_loop_index=has_loop_index,
      unroll=unroll)
  if not has_loop_index:
    return out
  # Put a value for the loop index back in front of the carried results.
  # NOTE(review): this reports `length` regardless of `loop_index_start`;
  # confirm that is intended when the start is nonzero.
  final_index = ir_constant(
      length, mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))
  )
  return [final_index, *out]
lowering_rules[lax.scan_p] = _scan_lowering_rule
skip_mlir_conversions.add(lax.scan_p)


def _lower_while_via_fori(
    ctx: LoweringRuleContext,
    *args,
    fori_jaxpr,
    cond_nconsts,
    cond_jaxpr,
    body_nconsts,
    body_jaxpr,
):
  """Lowers a while loop that was pattern-matched to a fori loop.

  The first two carried values are taken as the lower and upper bound of the
  loop index; the remaining ones are the carry of `fori_jaxpr`.

  Returns:
    `[ub, ub, *carry_out]`, i.e. the upper bound in both bound slots followed
    by the final carry.
  """
  del cond_jaxpr, body_jaxpr  # Unused.
  _, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts])
  lb, ub, *loop_carry = carry
  # Drop the block shape entry at position `body_nconsts + 1`.
  fori_block_shapes = (
      ctx.block_shapes[: body_nconsts + 1]
      + ctx.block_shapes[body_nconsts + 2 :]
  )
  fori_ctx = ctx.replace(block_shapes=fori_block_shapes)
  trip_count = arith.subi(ub, lb)
  for_out = _lower_jaxpr_to_for_loop(
      fori_ctx,
      fori_jaxpr,
      lb,
      trip_count,
      body_consts,
      *loop_carry,
      has_loop_index=True,
      unroll=1,
  )
  return [ub, ub, *for_out]


def _while_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    cond_nconsts,
    cond_jaxpr,
    body_nconsts,
    body_jaxpr,
):
  """Lowers `while` to a fori-style loop when possible, else to `scf.while`.

  Args:
    ctx: Lowering context of the rule.
    *args: Condition consts, body consts and the initial carry, in that order.
    cond_nconsts: Number of condition consts.
    cond_jaxpr: Closed jaxpr of the loop condition.
    body_nconsts: Number of body consts.
    body_jaxpr: Closed jaxpr of the loop body.

  Returns:
    The final values of the loop carry.
  """
  # First try to lower via a simpler fori loop, which may optimize better.
  fori_jaxpr, _ = pallas_utils.pattern_match_while_to_fori_loop(
      cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts
  )
  if fori_jaxpr is not None:
    return _lower_while_via_fori(
        ctx,
        *args,
        fori_jaxpr=fori_jaxpr,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr,
    )

  # If we fail conversion to fori, fallback to an ordinary while loop.
  split_points = [cond_nconsts, body_nconsts]
  cond_consts, body_consts, carry = split_list(args, split_points)
  cond_shapes, body_shapes, carry_shapes = split_list(
      ctx.block_shapes, split_points
  )
  carry_types = [value.type for value in carry]
  while_op = scf.WhileOp(carry_types, carry)

  # "before" region: evaluate the condition on the current carry.
  before_block = while_op.before.blocks.append(*carry_types)
  with ir.InsertionPoint.at_block_begin(before_block):
    cond_ctx = ctx.lowering_context.replace(
        block_shapes=[*cond_shapes, *carry_shapes]
    )
    [cond] = jaxpr_subcomp(
        cond_ctx,
        cond_jaxpr.jaxpr,
        *cond_consts,
        *before_block.arguments,
    )
    scf.condition(cond, before_block.arguments)

  # "after" region: run the body and yield the next carry.
  after_block = while_op.after.blocks.append(*carry_types)
  with ir.InsertionPoint.at_block_begin(after_block):
    body_ctx = ctx.lowering_context.replace(
        block_shapes=[*body_shapes, *carry_shapes],
    )
    loop_out = jaxpr_subcomp(
        body_ctx,
        body_jaxpr.jaxpr,
        *body_consts,
        *after_block.arguments,
    )
    if loop_out:
      scf.yield_(loop_out)
  return list(while_op.results)


lowering_rules[lax.while_p] = _while_lowering_rule

def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
  """Lowers `cond` to an `scf.if` (nested for more than two branches).

  The first argument is the branch index. Index 0 runs `branches[0]` in the
  else-block; any other index is handled in the then-block, either by
  `branches[1]` or by recursing with `index - 1` over `branches[1:]`.
  """
  index, *operands = args
  out_types = map(aval_to_ir_type, ctx.avals_out)
  index_is_nonzero = arith.cmpi(
      arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
  )
  if_op = scf.IfOp(index_is_nonzero, out_types, hasElse=True)
  # The branches do not take the index, so drop its block shape.
  branch_ctx = ctx.lowering_context.replace(
      block_shapes=ctx.block_shapes[1:],
  )
  with ir.InsertionPoint(if_op.then_block):
    # TODO(b/300272065): Use `scf.IndexSwitchOp` instead of a cascade of
    # if/else.
    if len(branches) > 2:
      then_out = _cond_lowering_rule(
          ctx,
          arith.subi(index, ir_constant(1, index.type)),
          *operands,
          branches=branches[1:],
      )
    else:
      then_out = jaxpr_subcomp(branch_ctx, branches[1].jaxpr, *operands)
    scf.YieldOp(then_out)
  with ir.InsertionPoint(if_op.else_block):
    else_out = jaxpr_subcomp(branch_ctx, branches[0].jaxpr, *operands)
    scf.YieldOp(else_out)
  return if_op.results


lowering_rules[lax.cond_p] = _cond_lowering_rule


def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
  """Lowers `pjit` by lowering `jaxpr.jaxpr` inline; other params are ignored."""
  inner_ctx = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
  return jaxpr_subcomp(inner_ctx, jaxpr.jaxpr, *args)


lowering_rules[pjit.pjit_p] = _pjit_lowering_rule


def _custom_jvp_call_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    call_jaxpr: jax_core.Jaxpr,
    jvp_jaxpr_thunk: Callable,
    num_consts: int,
    symbolic_zeros: bool,
):
  """Lowers `custom_jvp_call` by lowering `call_jaxpr.jaxpr` inline.

  Raises:
    NotImplementedError: If `symbolic_zeros` is set, `num_consts` is nonzero,
      or `call_jaxpr` carries constants.
  """
  del jvp_jaxpr_thunk
  if symbolic_zeros:
    raise NotImplementedError
  if num_consts:
    raise NotImplementedError
  if call_jaxpr.consts:
    raise NotImplementedError
  inner_ctx = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
  return jaxpr_subcomp(inner_ctx, call_jaxpr.jaxpr, *args)


lowering_rules[custom_derivatives.custom_jvp_call_p] = (
    _custom_jvp_call_lowering_rule)


def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
  """Lowers `debug_callback` to nothing and returns no results."""
  # No-op debug callbacks in Mosaic for now
  del ctx, args, kwargs
  no_results = []
  return no_results


lowering_rules[debugging.debug_callback_p] = _debug_callback_lowering_rule


def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int):

  if ctx.lowering_context.user_grid_indices is None:
    raise ValueError(
        f"program id: {axis} was passed, but user did not provide a grid."
    )
  length = len(ctx.lowering_context.user_grid_indices)
  if not (0 <= axis < length):
    raise ValueError(
        f"user passed in program id with axis: {axis}, but grid only has"
        f" length: {length}"
    )
  return ctx.lowering_context.user_grid_indices[axis]
lowering_rules[primitives.program_id_p] = _program_id_lowering_rule

def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int):
  mapped_axes = set(ctx.lowering_context.mapped_dims)
  seen_user_axes = 0
  for i in range(ctx.lowering_context.grid_rank):
    seen_user_axes += int(i not in mapped_axes)
    if seen_user_axes == axis + 1:
      break
  else:
    raise ValueError(
        f"user passed in program id with axis: {axis}, but grid only has"
        f" length: {len(ctx.lowering_context.grid_rank)}"
    )
  return tpu.iteration_bound(i)
lowering_rules[primitives.num_programs_p] = _num_programs_lowering_rule


def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis):
  """Lowers `repeat` to `tpu.repeat` with the given `axis` and `repeats`."""
  (out_aval,) = ctx.avals_out
  result_type = aval_to_ir_type(out_aval)
  return tpu.repeat(result_type, x, axis, repeats)


lowering_rules[tpu_primitives.repeat_p] = _repeat_lowering_rule


def _roll_lowering_rule(
    ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis
):
  """Lowers `roll` to `tpu.dynamic_rotate`.

  `stride` and `stride_axis` are forwarded as the op's `stride` and
  `stride_dimension`.
  """
  (out_aval,) = ctx.avals_out
  result_type = aval_to_ir_type(out_aval)
  return tpu.dynamic_rotate(
      result_type,
      x,
      shift,
      axis,
      stride=stride,
      stride_dimension=stride_axis,
  )


lowering_rules[tpu_primitives.roll_p] = _roll_lowering_rule


def _slice_lowering_rule(
    ctx: LoweringRuleContext, x, limit_indices, start_indices, strides
):
  """Lowers a slice to `vector.extract_strided_slice`.

  A `strides` of None is treated as a stride of 1 along every dimension.
  The slice sizes passed to the op are `limit_indices - start_indices`.
  """
  [result_aval] = ctx.avals_out
  result_type = aval_to_ir_type(result_aval)
  effective_strides = (
      [1] * len(start_indices) if strides is None else strides
  )
  extents = np.array(limit_indices) - np.array(start_indices)
  return vector.extract_strided_slice(
      result_type, x, start_indices, extents, effective_strides
  )


lowering_rules[lax.slice_p] = _slice_lowering_rule


def _xor_lowering_rule(ctx: LoweringRuleContext, x, y):
  """Lowers `xor_p` to `arith.xori` on the `_bcast`-processed operands."""
  lhs, rhs = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
  result = arith.xori(lhs, rhs)
  return result


lowering_rules[lax.xor_p] = _xor_lowering_rule
skip_mlir_conversions.add(lax.xor_p)


def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d):
  """Lowers `shift_left_p` to `arith.shli` on the `_bcast`-processed operands."""
  value, amount = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
  result = arith.shli(value, amount)
  return result


lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule
skip_mlir_conversions.add(lax.shift_left_p)


def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d):
  """Lowers `shift_right_arithmetic_p` to `arith.shrsi` (signed shift)."""
  value, amount = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
  result = arith.shrsi(value, amount)
  return result


lowering_rules[lax.shift_right_arithmetic_p] = _shift_right_arithmetic_lowering_rule
skip_mlir_conversions.add(lax.shift_right_arithmetic_p)


def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d):
  """Lowers `shift_right_logical_p` to `arith.shrui` (unsigned shift)."""
  value, amount = _bcast(x, d, *ctx.avals_in, *ctx.avals_out)
  result = arith.shrui(value, amount)
  return result


lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules
skip_mlir_conversions.add(lax.shift_right_logical_p)


def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers `erf_inv_p` by lowering `pallas_utils.erf_inv_lowering_helper`."""
  lowered_helper = lower_fun(
      pallas_utils.erf_inv_lowering_helper, multiple_results=False
  )
  return lowered_helper(ctx, x)


lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule


def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
  """Lowers `bitcast_p` to `tpu.bitcast`."""
  # The result type is derived from the output aval, so `ty` is not needed.
  del ty
  [result_aval] = ctx.avals_out
  result_type = aval_to_ir_type(result_aval)
  return tpu.bitcast(result_type, x)

lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule

def _bitcast_convert_type_lowering_rule(
    ctx: LoweringRuleContext, x, *, new_dtype):
  """Lowers `bitcast_convert_type_p` to `tpu.bitcast`.

  Raises:
    NotImplementedError: if the input dtype and `new_dtype` have different
      bit widths.
  """
  [src_aval] = ctx.avals_in
  [dst_aval] = ctx.avals_out
  src_bits = pallas_utils.dtype_bitwidth(src_aval.dtype)
  dst_bits = pallas_utils.dtype_bitwidth(new_dtype)
  if src_bits == dst_bits:
    return tpu.bitcast(aval_to_ir_type(dst_aval), x)
  raise NotImplementedError("Changing bitwidths not supported.")
lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule

def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value:
  """Emits an allocation for `aval` and returns the allocated value.

  Memory refs with a semaphore dtype and abstract semaphores are allocated
  with `tpu.sem_alloc`; any other memory ref is allocated with
  `memref.alloca` in the ref's memory space.

  Raises:
    NotImplementedError: if `aval` is neither a memory ref nor a semaphore.
  """
  if isinstance(aval, pallas_core.AbstractMemoryRef):
    memspace = _memory_space_to_mosaic_attribute(aval.memory_space)
    if not jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype):
      buffer_type = ir.MemRefType.get(
          aval.shape,
          _dtype_to_ir_type(aval.dtype, is_kernel_boundary=True),
          memory_space=memspace,
      )
      return memref.alloca(buffer_type, [], [])
    # Semaphore-typed refs must live in the semaphore memory space.
    assert aval.memory_space == TPUMemorySpace.SEMAPHORE
    sem_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE)
    return tpu.sem_alloc(sem_type)
  if isinstance(aval, tpu_core.AbstractSemaphore):
    sem_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE)
    return tpu.sem_alloc(sem_type)
  raise NotImplementedError(f"Cannot allocate {type(aval)}.")


def _run_scoped_lowering_rule(ctx: LoweringRuleContext, *consts, jaxpr):
  """Lowers `run_scoped_p` into a `tpu.RegionOp`.

  Inside the region body, one value is allocated with `_alloc_value` for
  each original input variable of `jaxpr`. The converted jaxpr is then
  lowered with `consts` followed by those allocations as arguments, and its
  outputs are yielded from the region.

  Returns:
    The results of the region op.
  """
  result_types = [aval_to_ir_type(aval) for aval in ctx.avals_out]
  region = tpu.RegionOp(result_types)
  # Read the invars before the constvar conversion below, so only the
  # jaxpr's own inputs are allocated.
  scoped_avals = [var.aval for var in jaxpr.invars]
  with ctx.lowering_context.grid_name_context():
    converted_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  with ir.InsertionPoint(region.body):
    scoped_values = [_alloc_value(aval) for aval in scoped_avals]
    # Only refs carry a block shape; every other allocation gets None.
    scoped_block_shapes = tuple(
        aval.shape if isinstance(aval, state.AbstractRef) else None
        for aval in scoped_avals
    )
    inner_ctx = ctx.lowering_context.replace(
        block_shapes=(*ctx.block_shapes, *scoped_block_shapes)
    )
    results = jaxpr_subcomp(
        inner_ctx, converted_jaxpr, *consts, *scoped_values
    )
    tpu.YieldOp(results)
  return region.results


lowering_rules[primitives.run_scoped_p] = _run_scoped_lowering_rule

def _device_id_to_logical(
    ctx: LoweringRuleContext, device_id,
    device_id_type: tpu_primitives.DeviceIdType):
  """Converts `device_id` into a single logical device id value.

  For `DeviceIdType.LOGICAL` the id is returned unchanged. For
  `DeviceIdType.MESH` the leaves of `device_id` are the mesh coordinates of
  the device: each is multiplied by the matching entry of the mesh context's
  `mesh_strides` and the products are summed. With no coordinates at all, an
  i32 zero constant is returned.

  Raises:
    NotImplementedError: for any other device id type.
  """
  if device_id_type is tpu_primitives.DeviceIdType.LOGICAL:
    return device_id
  if device_id_type is not tpu_primitives.DeviceIdType.MESH:
    raise NotImplementedError(f"Unsupported device id type: {device_id_type}")

  coords = tree_util.tree_leaves(device_id)
  strides = ctx.lowering_context.mesh_context.mesh_strides
  i32 = ir.IntegerType.get_signless(32)
  if not coords:
    return arith.constant(i32, 0)
  scaled_coords = (
      arith.muli(coord, arith.constant(i32, stride))
      for coord, stride in zip(coords, strides)
  )
  return functools.reduce(arith.addi, scaled_coords)


def _semaphore_read_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
):
  """Lowers `semaphore_read_p` to `tpu.sem_read` on the transformed ref."""
  sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem_ref, sem_transforms = tree_util.tree_unflatten(args_tree, args)
  indexed_sem, _ = _transform_ref(
      sem_ref, sem_aval.dtype, sem_aval.shape, sem_transforms
  )
  return tpu.sem_read(indexed_sem)


lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule

def _semaphore_signal_lowering_rule(
    ctx: LoweringRuleContext,
    *args,
    args_tree,
    device_id_type: tpu_primitives.DeviceIdType,
):
  """Lowers `semaphore_signal_p` to `tpu.sem_signal`.

  A device id, when present, is first converted with
  `_device_id_to_logical`. The rule produces no results.
  """
  sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem_ref, sem_transforms, value, device_id, core_index = (
      tree_util.tree_unflatten(args_tree, args)
  )
  indexed_sem, _ = _transform_ref(
      sem_ref, sem_aval.dtype, sem_aval.shape, sem_transforms
  )
  if device_id is None:
    logical_id = None
  else:
    logical_id = _device_id_to_logical(ctx, device_id, device_id_type)
  tpu.sem_signal(indexed_sem, value, device_id=logical_id, core_id=core_index)
  return []


lowering_rules[tpu_primitives.semaphore_signal_p] = (
    _semaphore_signal_lowering_rule)


def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
  """Lowers `semaphore_wait_p` to `tpu.sem_wait`; produces no results."""
  sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in)
  sem_ref, sem_transforms, value = tree_util.tree_unflatten(args_tree, args)
  indexed_sem, _ = _transform_ref(
      sem_ref, sem_aval.dtype, sem_aval.shape, sem_transforms
  )
  tpu.sem_wait(indexed_sem, value)
  return []
lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule

def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                             device_id_type: tpu_primitives.DeviceIdType):
  """Lowers `dma_start_p` to `tpu.enqueue_dma`.

  The flat arguments unflatten (via `tree`) into the source ref, destination
  ref, DMA semaphore and optional source semaphore, each followed by its
  transforms, plus an optional device id. Source and destination refs are
  transformed with their block shapes, the semaphores with their aval
  shapes. The rule produces no results.

  Raises:
    NotImplementedError: if the source ref has a bool dtype.
  """
  (
      src_ref,
      src_transforms,
      dst_ref,
      dst_transforms,
      sem,
      sem_transforms,
      src_sem,
      src_sem_transforms,
      device_id,
  ) = tree_util.tree_unflatten(tree, args)
  avals = tree_util.tree_unflatten(tree, ctx.avals_in)
  (src_ref_aval, _, dst_ref_aval, _, sem_aval, _, src_sem_aval, _, _) = avals
  if src_ref_aval.dtype == jnp.bool_:
    raise NotImplementedError("DMAs with bool dtypes are not supported.")

  block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
  # Positions 0 and 2 of the unflattened tree are the source/destination refs.
  src_block_shape = block_shapes[0]
  dst_block_shape = block_shapes[2]

  src, _ = _transform_ref(
      src_ref, src_ref_aval.dtype, src_block_shape, src_transforms
  )
  if src_sem is not None:
    src_sem, _ = _transform_ref(
        src_sem, src_sem_aval.dtype, src_sem_aval.shape, src_sem_transforms
    )
  dst, _ = _transform_ref(
      dst_ref, dst_ref_aval.dtype, dst_block_shape, dst_transforms
  )
  dma_sem, _ = _transform_ref(
      sem, sem_aval.dtype, sem_aval.shape, sem_transforms
  )
  if device_id is None:
    logical_id = None
  else:
    logical_id = _device_id_to_logical(ctx, device_id, device_id_type)
  tpu.enqueue_dma(
      src, dst, dma_sem, source_semaphore=src_sem, device_id=logical_id
  )
  return []
lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule


def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
                            device_id_type: tpu_primitives.DeviceIdType):
  """Lowers `dma_wait_p` to `tpu.wait_dma`; produces no results.

  Only the ref at position 2 of the unflattened arguments and the semaphore
  at position 4 are used; `device_id_type` is ignored.
  """
  del device_id_type
  unflattened_args = tree_util.tree_unflatten(tree, args)
  (_, _, ref, transforms, sem, sem_transforms, _, _, _) = unflattened_args
  unflattened_avals = tree_util.tree_unflatten(tree, ctx.avals_in)
  (_, _, ref_aval, _, sem_aval, _, _, _, _) = unflattened_avals
  block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
  waited_ref, _ = _transform_ref(
      ref, ref_aval.dtype, block_shapes[2], transforms
  )
  waited_sem, _ = _transform_ref(
      sem, sem_aval.dtype, sem_aval.shape, sem_transforms
  )
  tpu.wait_dma(waited_sem, waited_ref)
  return []
lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule

def _device_id_lowering_rule(ctx: LoweringRuleContext):
  """Lowers `device_id_p` to `tpu.device_id`."""
  del ctx  # Unused.
  current_device = tpu.device_id()
  return current_device
lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule

def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
  """Lowers `axis_index_p` for either a grid axis or a mesh axis.

  If `axis_name` is one of the lowering context's grid names, the result is
  the program id along that grid axis. Otherwise the name is looked up in
  the mesh context and the index is computed from the device id as
  `(device_id / prod(mesh_shape[axis + 1:])) % mesh_shape[axis]`.

  Raises:
    ValueError: if a mesh axis is queried but no mesh context is set.
  """
  lowering_ctx = ctx.lowering_context
  grid_names = lowering_ctx.grid_names
  if grid_names and axis_name in grid_names:
    # The name refers to a grid dimension.
    grid_axis = grid_names.index(axis_name)
    return _program_id_lowering_rule(ctx, axis=grid_axis)

  # The name refers to a mesh dimension.
  flat_device_id = tpu.device_id()
  mesh_context = lowering_ctx.mesh_context
  if mesh_context is None:
    raise ValueError("Mesh context is not set.")
  mesh_shape = mesh_context.mesh_shape
  position = mesh_context.axis_names.index(axis_name)
  axis_size = ir_constant(mesh_shape[position])
  # Product of the sizes of all mesh axes after the queried one.
  trailing_sizes = mesh_shape[position + 1 :]
  minor_divisor = ir_constant(np.prod(trailing_sizes, dtype=np.int32))
  quotient = arith.divsi(flat_device_id, minor_divisor)
  return arith.remsi(quotient, axis_size)
lowering_rules[lax.axis_index_p] = _axis_index_rule

def _get_barrier_semaphore_rule(ctx: LoweringRuleContext):
  """Lowers `get_barrier_semaphore_p` to `tpu.sem_barrier`.

  The memref type is derived from the first output aval.
  """
  barrier_aval = ctx.avals_out[0]
  return tpu.sem_barrier(aval_to_ir_type(barrier_aval))
lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule


def _delay_rule(ctx: LoweringRuleContext, nanos: int):
  """Lowers `delay_p` to `tpu.delay(nanos)`; produces no results."""
  del ctx  # Unused.
  tpu.delay(nanos)
  return []


lowering_rules[tpu_primitives.delay_p] = _delay_rule


def _debug_print_tpu_format(fmt: str) -> str:
  """Rewrites the `{...}` placeholders of `fmt` as `$0`, `$1`, ... in order."""
  pieces = []
  num_fields = 0
  for text, field_name, _, _ in string.Formatter().parse(fmt):
    pieces.append(text)
    # `parse` also yields literal-only segments (trailing text, escaped
    # braces) with `field_name` set to None; those must not get a `$N`.
    if field_name is not None:
      pieces.append(f"${num_fields}")
      num_fields += 1
  return "".join(pieces)


def _debug_print_rule(
    ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool
):
  """Lowers `debug_print_p` to `tpu.log`.

  Args:
    ctx: The lowering rule context.
    *args: The values to print; every input aval must have an empty shape.
    fmt: The format string. It is validated against `args` with
      `primitives.check_debug_print_format`.
    has_placeholders: Whether `fmt` contains `{...}` placeholders. If so,
      they are rewritten to the `$0`, `$1`, ... form and the log is emitted
      as formatted; otherwise `fmt` is passed through unchanged.

  Returns:
    An empty tuple; the rule produces no results.

  Raises:
    NotImplementedError: if any input is not a scalar.
    TypeError: if placeholders are used and some argument is not a 32-bit
      integer.
  """
  if any(aval.shape for aval in ctx.avals_in):
    raise NotImplementedError("Only scalar values are supported")

  primitives.check_debug_print_format(fmt, *args)
  if has_placeholders:
    if not all(
        isinstance(arg.type, ir.IntegerType) and arg.type.width == 32
        for arg in args
    ):
      raise TypeError(
          "All arguments must be 32-bit integers when using"
          " placeholders (`{...}`). If you need to print values of other types,"
          " remove placeholders from the format string."
      )

    # TPU expects $0, $1 etc as placeholders.
    tpu_fmt = _debug_print_tpu_format(fmt)
  else:
    tpu_fmt = fmt
  tpu.log(args, tpu_fmt, formatted=has_placeholders)
  return ()


lowering_rules[primitives.debug_print_p] = _debug_print_rule


def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds):
  """Lowers `prng_seed_p` to `tpu.prng_set_seed_32`; produces no results.

  Raises:
    ValueError: if the seeds are not a single `KeyScalarBundle` and some
      seed does not have an integer IR type.
  """
  del ctx
  # A lone KeyScalarBundle is unpacked and its scalars are used as the seed.
  if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle):
    (bundle,) = seeds
    tpu.prng_set_seed_32(bundle.scalars)
    return []
  # Integer seeds are passed through as-is: PRNGSeed32Op natively takes a
  # list of integers.
  has_non_integer = any(
      not isinstance(seed.type, ir.IntegerType) for seed in seeds
  )
  if has_non_integer:
    seed_types = [seed.type for seed in seeds]
    raise ValueError(f"All seed data must be scalar integers. Got {seed_types}")
  tpu.prng_set_seed_32(seeds)
  return []
lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule


def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape):
  """Lowers `prng_random_bits_p` to `tpu.prng_random_bits`.

  Raises:
    NotImplementedError: if `shape` has fewer than two dimensions.
  """
  rank = len(shape)
  if rank < 2:
    # TODO(b/342054464): Support implicit dims for PRNGRandomBitsOp.
    raise NotImplementedError("random_bits only supports rank>=2 outputs.")
  result_type = aval_to_ir_type(ctx.avals_out[0])
  return tpu.prng_random_bits(result_type)
lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule


def random_seed_lowering(ctx, seeds, *, impl):
  """Lowers `random_seed_p` by lowering the PRNG implementation's `seed`."""
  lowered_seed = lower_fun(impl.seed, multiple_results=False)
  result = lowered_seed(ctx, seeds)
  return result
lowering_rules[prng.random_seed_p] = random_seed_lowering


def random_bits_lowering(ctx, keys, *, bit_width, shape):
  """Lowers `random_bits_p` via the key implementation's `random_bits`.

  For implementations that are not Pallas implementations, the key is first
  converted to its raw key data cast to uint32. Only `bit_width == 32` is
  accepted.
  """
  assert bit_width == 32, "Only 32-bit PRNG supported."
  (key_aval,) = ctx.avals_in
  impl = key_aval.dtype._impl
  _proxy_fn = impl.random_bits
  if not pl_random.is_pallas_impl(impl):
    def new_lowering(key, bit_width, shape):
      raw_key = jax.random.key_data(key).astype(jnp.uint32)
      return impl.random_bits(raw_key, bit_width, shape)
    _proxy_fn = new_lowering
  lowered_bits = lower_fun(_proxy_fn, multiple_results=False)
  return lowered_bits(ctx, keys, bit_width=bit_width, shape=shape)
lowering_rules[prng.random_bits_p] = random_bits_lowering


def random_fold_in_lowering(ctx, keys, msgs):
  """Lowers `random_fold_in_p` via the key implementation's `fold_in`."""
  key_aval, _ = ctx.avals_in
  key_impl = key_aval.dtype._impl
  lowered_fold_in = lower_fun(key_impl.fold_in, multiple_results=False)
  return lowered_fold_in(ctx, keys, msgs)
lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering


def random_unwrap_lowering(ctx, key):
  """Lowers `random_unwrap_p`.

  Keys of non-Pallas implementations are returned unchanged. For Pallas
  implementations the key must be a `KeyScalarBundle` of key shape (1, 1);
  its first scalar is broadcast to an int32 vector of that shape.

  Raises:
    NotImplementedError: if the bundle's key shape is not (1, 1).
  """
  key_impl = ctx.avals_in[0].dtype._impl
  if not pl_random.is_pallas_impl(key_impl):
    return key
  assert isinstance(key, KeyScalarBundle)
  if tuple(key.key_shape) != (1, 1):
    raise NotImplementedError(
      "Seed key_data of shape != (1, 1) not supported. "
      f"Got: {key.key_shape}")
  # Convert the scalar to a vector.
  first_scalar = key.scalars[0]
  vec_type = ir.VectorType.get(
      key.key_shape, _dtype_to_ir_type(jnp.dtype('int32'))
  )
  return vector.broadcast(vec_type, first_scalar)
lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering


def random_wrap_lowering(ctx, key_data, *, impl):
  """Lowers `random_wrap_p`.

  For non-Pallas implementations `key_data` is returned unchanged. For
  Pallas implementations, an existing `KeyScalarBundle` is returned as-is,
  and a (1, 1) vector is unpacked into a `KeyScalarBundle` of extracted
  scalars.

  Raises:
    NotImplementedError: if `key_data` is a vector that is not of shape
      (1, 1), or is neither a vector nor a `KeyScalarBundle`.
  """
  del ctx
  if not pl_random.is_pallas_impl(impl):
    return key_data
  # Handle bundles before touching `.type`: elsewhere in this file a
  # KeyScalarBundle is only used through `scalars` and `key_shape`, so reading
  # `.type` first would presumably fail before the pass-through was reached.
  # NOTE(review): confirm against the KeyScalarBundle definition.
  if isinstance(key_data, KeyScalarBundle):
    return key_data
  if not isinstance(getattr(key_data, "type", None), ir.VectorType):
    raise NotImplementedError(f"key_data wrap {type(key_data)}")

  # The key data lives in vregs and needs to be unpacked to sregs.
  key_data_shape = key_data.type.shape
  if len(key_data_shape) != 2:
    raise NotImplementedError("Seed key_data must be 2D.")
  if tuple(key_data_shape) != (1, 1):
    raise NotImplementedError(
      "Seed key_data of shape != (1, 1) not supported. "
      f"Got: {key_data_shape}")
  key_data_list = [
      vector.ExtractOp(key_data, [], [0, i])
      for i in range(key_data_shape[1])
  ]
  return KeyScalarBundle(
      scalars=key_data_list, key_shape=tuple(key_data_shape))

lowering_rules[prng.random_wrap_p] = random_wrap_lowering

def _checkify_lowering_rule(
    ctx: LoweringRuleContext, *err_args, err_tree, debug):
  """Lowers `checkify.check_p` to a `cf.assert` operation.

  When runtime asserts are disabled, debug checks lower to nothing and
  non-debug checks raise `LoweringException`. Otherwise the single predicate
  carried by the error is inverted and attached to a `cf.assert` op whose
  `msg` attribute is the failed check's format string.
  """
  if not tpu_core.runtime_assert_enabled():
    if not debug:
      raise LoweringException("Non-debug check must be functionalized. "
                              "Enable runtime asserts with "
                              "--jax_pallas_enable_runtime_assert "
                              "or functionalize with checkify.check.")
    return []

  assert ctx.lowering_context.ir_context.allow_unregistered_dialects, (
    "allow_unregistered_dialects must be set to True for "
    "runtime assert check.")
  error = jax.tree.unflatten(err_tree, err_args)
  # Exactly one predicate/metadata/payload entry is expected.
  assert len(error._pred) == 1
  assert len(error._metadata) == 1
  assert len(error._payload) == 1
  _, pred = next(iter(error._pred.items()))
  _, exception_tree = next(iter(error._metadata.items()))
  _, payload = next(iter(error._payload.items()))
  exception = jax.tree.unflatten(exception_tree, payload)
  assert isinstance(exception, checkify.FailedCheckError)

  # check_p has an inverted predicate compared to assert,
  # so not(pred) is computed here by xor-ing with an all-ones constant.
  bool_type = _dtype_to_ir_type(jnp.dtype('bool'))
  all_ones = ir_constant(-1, bool_type)
  inverted_pred = arith.xori(pred, all_ones)
  ir.Operation.create(
      "cf.assert",
      operands=(inverted_pred,),
      attributes={"msg": ir.StringAttr.get(exception.fmt_string)},
  )
  return []
lowering_rules[checkify.check_p] = _checkify_lowering_rule

def _threefry2x32_lowering(ctx, k1, k2, m1, m2):
  """Lowers `threefry2x32_p` by lowering `prng._threefry2x32_lowering`.

  The wrapped call runs under the "threefry2x32" named scope with
  `use_rolled_loops=False`.
  """
  def _lower_fun(k1, k2, m1, m2):
    with jax.named_scope("threefry2x32"):
      return prng._threefry2x32_lowering(
          k1, k2, m1, m2, use_rolled_loops=False
      )

  lowered_threefry = lower_fun(_lower_fun, multiple_results=True)
  return lowered_threefry(ctx, k1, k2, m1, m2)


lowering_rules[prng.threefry2x32_p] = _threefry2x32_lowering


def _iota_2x32_shape_lowering(ctx, *, shape):
  """Lowers `iota_2x32_shape_p` to a pair of int32 arrays of `shape`.

  The second array holds each element's flattened index, accumulated from
  the last dimension to the first; the first array is all zeros.

  Raises:
    NotImplementedError: if `shape` has more elements than the int32 maximum.
  """
  int32_max = np.iinfo(jnp.int32).max
  if np.prod(shape) > int32_max:
    raise NotImplementedError(f"Iota with >{int32_max} items.")

  def _lower_fun(shape):
    flat_index = jnp.zeros(shape, dtype=jnp.int32)
    stride = 1
    # Walk from the minor-most dimension outwards, scaling each dimension's
    # iota by the product of the sizes of the dimensions after it.
    for dim in reversed(range(len(shape))):
      dim_index = lax.broadcasted_iota(
          dtype=jnp.int32, shape=shape, dimension=dim
      )
      flat_index += dim_index * stride
      stride *= shape[dim]
    high_words = jnp.zeros(shape, dtype=jnp.int32)
    return high_words, flat_index

  lowered_iota = lower_fun(_lower_fun, multiple_results=True)
  return lowered_iota(ctx, shape=shape)


lowering_rules[prng.iota_2x32_shape_p] = _iota_2x32_shape_lowering


def _pad_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs):
  """Lowers `pad_p` by concatenating padding vectors around the operand.

  Args:
    ctx: The lowering rule context.
    *args: The operand and the padding value, in that order.
    **kwargs: Must contain `padding_config`, a sequence of
      `(low, high, interior)` triples, one per axis.

  Returns:
    The padded operand. Axes are processed in order; for each axis the low
    padding is concatenated before the operand, then the high padding after.

  Raises:
    NotImplementedError: if the operand's IR type is not a vector type, if
      any axis requests interior padding, or if any low/high padding is
      negative.
  """
  operand, padding_value = args
  padding_config = kwargs["padding_config"]

  out_type: ir.VectorType = aval_to_ir_type(ctx.avals_in[0])
  if not isinstance(out_type, ir.VectorType):
    raise NotImplementedError("Only vector types are supported.")

  # Validate the whole config up front so that an unsupported request does
  # not leave partially emitted padding ops behind.
  for low, high, interior in padding_config:
    if interior != 0:
      raise NotImplementedError("Not implemented: interior padding")
    if low < 0 or high < 0:
      # A negative amount would otherwise become a negative vector dimension.
      raise NotImplementedError("Not implemented: negative padding")

  def _make_pad(like, axis, size):
    """Builds a padding vector shaped like `like`, with `size` along `axis`."""
    shape = list(like.type.shape)
    shape[axis] = size
    element_type = like.type.element_type
    pad_vec_type = ir.VectorType.get(shape, element_type)
    if isinstance(padding_value, ir.OpResult):
      return vector.broadcast(pad_vec_type, padding_value)
    scalar_attr = ir.FloatAttr.get(element_type, padding_value)
    return arith.ConstantOp(
        pad_vec_type,
        ir.DenseElementsAttr.get_splat(pad_vec_type, scalar_attr),
    ).result

  def _grow(vec_type, axis, amount):
    """Returns `vec_type` with its `axis` dimension enlarged by `amount`."""
    shape = list(vec_type.shape)
    shape[axis] += amount
    return ir.VectorType.get(shape, vec_type.element_type)

  for axis, (low, high, _) in enumerate(padding_config):
    if low != 0:
      pad_low = _make_pad(operand, axis, low)
      out_type = _grow(out_type, axis, low)
      operand = tpu.concatenate(out_type, [pad_low, operand], dimension=axis)

    if high != 0:
      # Built from the current operand, so it already includes the low pad.
      pad_high = _make_pad(operand, axis, high)
      out_type = _grow(out_type, axis, high)
      operand = tpu.concatenate(out_type, [operand, pad_high], dimension=axis)

  return operand


lowering_rules[lax.pad_p] = _pad_lowering_rule
