def maybe_create_ubatch_slices(
should_ubatch: bool,
num_scheduled_tokens: np.ndarray,
num_tokens_padded: int,
num_reqs_padded: int,
num_microbatches: int,
split_point: list[int] | int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
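    """Split the scheduled tokens into per-microbatch (ubatch) slices.

    Returns a tuple of (unpadded, padded) UBatchSlices, or (None, None) when
    microbatching is disabled. Each UBatchSlice pairs a request slice with a
    token slice; the padded slices are extended so that their token counts
    sum to num_tokens_padded.
    """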
if not should_ubatch:
return None, None
    if split_point is None:
        split_point = int(num_tokens_padded) // num_microbatches
    if isinstance(split_point, int):
        # Equal-sized microbatches of `split_point` tokens each.
        token_split_points = [split_point * i for i in range(1, num_microbatches)]
    else:
        # An explicit list is used directly as the token split points.
        token_split_points = list(split_point)
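    # Example: num_tokens_padded=512, num_microbatches=2 gives
    # split_point=256 and token_split_points=[256].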
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
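    # e.g. num_scheduled_tokens=[4, 2, 3] -> cu_num_tokens=[0, 4, 6, 9]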
ubatch_slices = []
start_token = 0
# Add the end point to the split points to make iteration easier
all_points = token_split_points + [cu_num_tokens[-1]]
for end_token in all_points:
token_slice = slice(start_token, end_token)
        # Determine the request slice using exclusive-stop semantics: a
        # ubatch includes every request whose tokens overlap
        # [start_token, end_token). req_start is the request that contains
        # start_token (or the request starting exactly at start_token when
        # the split falls on a request boundary).
req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
# Stop at the request that starts at or after end_token
req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
req_slice = slice(req_start, req_stop)
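        # A request whose tokens straddle a split point is included in both
        # adjacent ubatch slices: it is the last request of the earlier slice
        # and the first request of the later one.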
ubatch_slices.append(UBatchSlice(req_slice, token_slice))
start_token = end_token
ubatch_slices_padded = _pad_out_ubatch_slices(
ubatch_slices, num_tokens_padded, num_reqs_padded
)
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
return ubatch_slices, ubatch_slices_padded
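
# Worked example (a sketch with made-up numbers, assuming an int32
# num_scheduled_tokens array and UBatchSlice(request_slice, token_slice)):
#
#   num_scheduled_tokens = np.array([4, 2, 3], dtype=np.int32)  # 9 real tokens
#   slices, padded = maybe_create_ubatch_slices(
#       should_ubatch=True,
#       num_scheduled_tokens=num_scheduled_tokens,
#       num_tokens_padded=10,
#       num_reqs_padded=4,
#       num_microbatches=2,
#   )
#   # split_point = 10 // 2 = 5, so the token boundaries are [0, 5) and [5, 9):
#   #   slices[0] == UBatchSlice(slice(0, 2), slice(0, 5))
#   #   slices[1] == UBatchSlice(slice(1, 3), slice(5, 9))
#   # Request 1 (tokens 4..5) straddles the split and appears in both slices.
#   # `padded` extends the slices (via _pad_out_ubatch_slices) so that their
#   # token counts sum to num_tokens_padded (10 here).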