class MultiLayerEagleProposer(EagleProposer):
    """EAGLE-style draft proposer for draft models with multiple predict
    layers.

    The draft model declares its layer count via ``n_predict`` in the HF
    text config; the proposer runs one drafting forward per layer (see
    ``dummy_run``), so the number of speculative tokens is forced to equal
    the layer count.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ) -> None:
        """Initialize the proposer and pin ``num_speculative_tokens`` to the
        draft model's predict-layer count.

        Args:
            vllm_config: Engine configuration; ``speculative_config`` is read
                for the draft model config and the requested number of
                speculative tokens.
            device: Device the draft model runs on (forwarded to the base).
            runner: Owning model runner (forwarded to the base).
        """
        super().__init__(vllm_config, device, runner)
        # Number of predict layers of the draft model; 0 when the HF text
        # config does not declare "n_predict".
        self.layer_num: int = getattr(
            self.speculative_config.draft_model_config.hf_text_config, "n_predict", 0
        )
        self.num_speculative_tokens: int = (
            self.speculative_config.num_speculative_tokens
        )
        # The multi-layer scheme ties one speculative token to each layer,
        # so a mismatching user setting is overridden (with a one-time
        # warning) rather than rejected.
        if self.num_speculative_tokens != self.layer_num:
            logger.warning_once(
                "For multi_layer_eagle, num_speculative_tokens "
                "does not match layer_num, adjusting to layer_num"
            )
            self.num_speculative_tokens = self.layer_num
    def adjust_input(
        self,
        batch_size: int,
        target_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        token_indices_to_sample: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        """Shift each request's drafting window and exchange the shifted-out
        prefix with the per-request cache.

        Computes a per-request ``shift`` (bounded below by 0 and above by
        both the number of tokens after the sample index and the request's
        starting position), mutates ``token_indices_to_sample`` and
        ``common_attn_metadata.seq_lens`` IN PLACE accordingly, then calls
        the ``_multi_layer_eagle_shift_and_cache`` kernel to build shifted
        copies of the token ids / positions / hidden states using the cached
        state in ``multi_layer_eagle_metadata``.

        Args:
            batch_size: Number of requests in the batch.
            target_token_ids: Flattened token ids for all requests.
            target_positions: Positions; either 1-D, or 2-D where row 0
                holds the positions used for the shift computation
                (presumably an M-RoPE-style layout — TODO confirm).
            target_hidden_states: Hidden states aligned with the tokens.
            token_indices_to_sample: Per-request sample indices; ``None``
                selects each request's last token. NOTE: mutated in place.
            common_attn_metadata: Batch attention metadata; ``seq_lens`` and
                (via the kernel) other fields are mutated in place.
            multi_layer_eagle_metadata: Per-request cache of previously
                shifted-out tokens/positions/hidden states; must not be None.

        Returns:
            Tuple of (shifted token ids, shifted positions, shifted hidden
            states, the mutated ``common_attn_metadata``).
        """
        assert multi_layer_eagle_metadata is not None
        # Default: sample at the last token of every request.
        if token_indices_to_sample is None:
            token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1
        MAX_SHIFT = self.layer_num
        assert MAX_SHIFT > 0
        # The kernel writes shifted data into these clones (dst_*) while
        # reading from the originals (src_*), so sources stay intact.
        prev_token_ids = target_token_ids.clone()
        prev_positions = target_positions.clone()
        prev_hidden_states = target_hidden_states.clone()
        slot_mapping = common_attn_metadata.slot_mapping
        # Per-request [start, end] token indices in the flattened batch.
        start_token_indices = common_attn_metadata.query_start_loc[:-1]
        end_token_indices = common_attn_metadata.query_start_loc[1:] - 1
        # For 2-D positions, row 0 carries the scalar positions used below.
        pos_for_shift = (
            target_positions[0] if target_positions.dim() == 2 else target_positions
        )
        start_token_pos = pos_for_shift[start_token_indices]
        # Shift by at most the tokens remaining after the sample index, and
        # never further back than position 0 of the sequence.
        shift = torch.minimum(
            end_token_indices - token_indices_to_sample,
            start_token_pos,
        )
        shift = torch.clamp(shift, min=0)
        # Metadata updates (matches the original reference implementation).
        token_indices_to_sample.add_(shift)
        common_attn_metadata.seq_lens.sub_(shift)
        # NOTE: ignore cpu data to avoid device sync
        # common_attn_metadata.seq_lens_cpu.copy_(common_attn_metadata.seq_lens,
        # non_blocking=True)
        # query_lens = common_attn_metadata.query_start_loc[
        # 1:] - common_attn_metadata.query_start_loc[:-1]
        # num_computed_tokens = common_attn_metadata.seq_lens - query_lens.to(
        # common_attn_metadata.seq_lens.dtype)
        # common_attn_metadata.num_computed_tokens_cpu.copy_(
        # num_computed_tokens.to(
        # common_attn_metadata.num_computed_tokens_cpu.dtype),
        # non_blocking=True,
        # )
        # common_attn_metadata.max_seq_len =
        # int(common_attn_metadata.seq_lens_cpu.max().item())
        cached_lens = multi_layer_eagle_metadata.cached_len
        # NOTE(review): the metadata above was updated with the un-clamped
        # shift, while the kernel below receives the shift additionally
        # clamped by the cache length. This ordering matches the reference
        # implementation — confirm it is intentional before changing.
        shift = torch.minimum(shift, cached_lens)
        _multi_layer_eagle_shift_and_cache(
            batch_size=batch_size,
            max_shift=MAX_SHIFT,
            src_token_ids=target_token_ids,
            dst_token_ids=prev_token_ids,
            src_positions=target_positions,
            dst_positions=prev_positions,
            src_hidden_states=target_hidden_states,
            dst_hidden_states=prev_hidden_states,
            src_slot_mapping=slot_mapping,
            dst_slot_mapping=slot_mapping,
            start_token_indices=start_token_indices,
            end_token_indices=end_token_indices,
            token_indices_to_sample=token_indices_to_sample,
            shift=shift,
            cached_lens=cached_lens,
            cached_prev_token_ids=multi_layer_eagle_metadata.cached_token_ids,
            cached_prev_positions=multi_layer_eagle_metadata.cached_positions,
            cached_prev_hidden_states=multi_layer_eagle_metadata.cached_hidden_states,
            cached_slot_mappings=multi_layer_eagle_metadata.cached_slot_mappings,
            common_attn_metadata=common_attn_metadata,
        )
        return prev_token_ids, prev_positions, prev_hidden_states, common_attn_metadata
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
raise Exception(
"speculative_config.disable_padded_drafter_batch"
" is not supported now for MultiLayerEagleProposer."
)
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the draft model on dummy inputs, once per predict layer.

        Used for warm-up / CUDA graph capture: pads ``num_tokens`` across
        data-parallel ranks, dispatches the cudagraph runtime mode, runs
        ``adjust_input`` once on a synthetic single-request batch (so its
        JIT kernel gets compiled), then performs ``self.layer_num`` forward
        passes with ``spec_step_idx`` set to the layer index.

        Args:
            num_tokens: Unpadded number of tokens for the dummy batch.
            use_cudagraphs: Dispatch through the cudagraph dispatcher and
                use its padded token count; otherwise run eagerly.
            is_graph_capturing: Unused here; kept for interface parity with
                the base ``dummy_run`` — TODO confirm.
            slot_mappings: Optional per-attention-layer slot mappings to run
                under; see the buffer-selection branch below.
        """
        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
        )
        if use_cudagraphs:
            # The dispatcher may pad further to a captured batch size.
            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens_dp_padded
            )
            num_input_tokens = batch_desc.num_tokens
        else:
            cudagraph_runtime_mode = CUDAGraphMode.NONE
            num_input_tokens = num_tokens_dp_padded
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens
        # Make sure to use EAGLE's own buffer during cudagraph capture.
        # NOTE(review): when the caller's slot_mappings already cover this
        # proposer's attention layers, they are replaced with the proposer's
        # own persistent mapping (sized to the padded token count) —
        # presumably so captured graphs reference stable buffers; confirm.
        if (
            self.attn_layer_names
            and slot_mappings is not None
            and self.attn_layer_names[0] in slot_mappings
        ):
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}
        # Synthetic single-request batch spanning all padded tokens, sampling
        # at the final token; slot_mapping is the identity (self.arange).
        adjust_input_kwargs = {
            "batch_size": 1,
            "target_token_ids": self.input_ids[:num_input_tokens],
            "target_positions": self._get_positions(num_input_tokens),
            "target_hidden_states": self.hidden_states[:num_input_tokens],
            "token_indices_to_sample": torch.tensor(
                [num_input_tokens - 1], dtype=torch.int32, device=self.device
            ),
            "common_attn_metadata": CommonAttentionMetadata(
                query_start_loc=torch.tensor(
                    [0, num_input_tokens], dtype=torch.int32, device=self.device
                ),
                query_start_loc_cpu=torch.tensor(
                    [0, num_input_tokens], dtype=torch.int32, device="cpu"
                ),
                seq_lens=torch.tensor(
                    [num_input_tokens], dtype=torch.int32, device=self.device
                ),
                num_reqs=1,
                num_actual_tokens=num_input_tokens,
                max_query_len=num_input_tokens,
                max_seq_len=self.max_model_len,
                block_table_tensor=torch.tensor(
                    [], dtype=torch.int32, device=self.device
                ),
                slot_mapping=self.arange[:num_input_tokens],
                logits_indices_padded=None,
                num_logits_indices=None,
                causal=True,
                encoder_seq_lens=None,
            ),
            "multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
                layer_num=self.layer_num,
                hidden_size=self.hidden_size,
                device=self.device,
            ),
        }
        # NOTE ensure the jit kernel in _adjust_input can be compiled
        self.adjust_input(**adjust_input_kwargs)
        # One forward pass per predict layer; spec_step_idx selects the layer.
        for fwd_idx in range(self.layer_num):
            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                slot_mapping=slot_mapping_dict,
            ):
                # Multimodal drafters take embeddings; text drafters take ids.
                if self.supports_mm_inputs:
                    input_ids = None
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]
                    inputs_embeds = None
                model_kwargs = {
                    "input_ids": input_ids,
                    "positions": self._get_positions(num_input_tokens),
                    "hidden_states": self.hidden_states[:num_input_tokens],
                    "inputs_embeds": inputs_embeds,
                    "spec_step_idx": fwd_idx,
                }
                self.model(**model_kwargs)