vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu

XPUFP8ScaledMMLinearKernel

Bases: FP8ScaledMMLinearKernel
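
FP8 scaled-MM linear kernel for the XPU backend. Based on the source below, it is only available on XPU platforms, requires an FP8 weight dtype (torch.float8_e5m2 or torch.float8_e4m3fn), and performs the weight-only FP8 GEMM through the fused torch.ops._xpu_C.fp8_gemm_w8a16 op.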

Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py
class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_xpu():
            return False, "XPUFP8ScaledMM only support on XPU"
        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
            return False, "XPUFP8ScaledMM only support FP8 weight dtype"
        return True, None

    def __init__(
        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
    ) -> None:
        assert self.can_implement(c)[0]
        assert self.is_supported()[0]
        self.config = c
        self.layer_param_names = layer_param_names

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        weight = layer.weight
        weight_scale = layer.weight_scale
        return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        pass
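
A minimal usage sketch of the check-then-construct flow, assuming you already have an FP8ScaledMMLinearLayerConfig instance cfg and a layer module exposing FP8 weight and weight_scale parameters (how those are produced is outside this kernel; the parameter names passed below are illustrative):

import torch

from vllm.model_executor.layers.quantization.kernels.scaled_mm.xpu import (
    XPUFP8ScaledMMLinearKernel,
)

# Both gates return (ok, reason); reason is None when the check passes.
ok, reason = XPUFP8ScaledMMLinearKernel.is_supported()
if ok:
    ok, reason = XPUFP8ScaledMMLinearKernel.can_implement(cfg)
if not ok:
    raise ValueError(f"XPU FP8 kernel unavailable: {reason}")

# layer_param_names is illustrative here.
kernel = XPUFP8ScaledMMLinearKernel(cfg, layer_param_names=["weight", "weight_scale"])

# layer is an nn.Module holding the FP8 `weight` and a `weight_scale` tensor;
# x holds bf16/fp16 activations.
y = kernel.apply_weights(layer, x)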

config instance-attribute

config = c

layer_param_names instance-attribute

layer_param_names = layer_param_names

__init__

__init__(
    c: FP8ScaledMMLinearLayerConfig,
    layer_param_names: Sequence[str],
) -> None
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py
def __init__(
    self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None:
    assert self.can_implement(c)[0]
    assert self.is_supported()[0]
    self.config = c
    self.layer_param_names = layer_param_names

apply_scaled_mm

apply_scaled_mm(
    *,
    A: Tensor,
    B: Tensor,
    out_dtype: dtype,
    As: Tensor,
    Bs: Tensor,
    bias: Tensor | None,
    output_shape: list,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py
def apply_scaled_mm(
    self,
    *,
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None,
    output_shape: list,
) -> torch.Tensor:
    pass
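
Note that the body is a bare pass, so despite the -> torch.Tensor annotation this method returns None; on this backend the FP8 GEMM is instead performed directly in apply_weights via the fused torch.ops._xpu_C.fp8_gemm_w8a16 op.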

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Tensor | None = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py
def apply_weights(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    weight = layer.weight
    weight_scale = layer.weight_scale
    return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)
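
For intuition, a dequantize-then-matmul reference of what the fused w8a16 GEMM computes, assuming a torch.nn.Linear-style (out_features, in_features) weight layout and a simple broadcastable weight scale; the real _xpu_C op fuses dequantization into the GEMM and may use a different layout or scale granularity:

import torch

def fp8_gemm_w8a16_reference(
    x: torch.Tensor,             # bf16/fp16 activations
    weight: torch.Tensor,        # FP8 weight, assumed (out_features, in_features)
    weight_scale: torch.Tensor,  # dequantization scale, assumed broadcastable
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Conceptual reference only: upcast the FP8 weight to the activation dtype,
    # apply the scale, then run a standard linear.
    w = weight.to(x.dtype) * weight_scale.to(x.dtype)
    return torch.nn.functional.linear(x, w, bias)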

can_implement classmethod

can_implement(
    c: FP8ScaledMMLinearLayerConfig,
) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
    if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
        return False, "XPUFP8ScaledMM only support FP8 weight dtype"
    return True, None
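
The gate only inspects c.weight_quant_key.dtype; a standalone illustration of the same dtype check:

import torch

FP8_DTYPES = {torch.float8_e5m2, torch.float8_e4m3fn}

def weight_dtype_ok(weight_dtype: torch.dtype) -> bool:
    # Mirrors can_implement: only FP8 weight dtypes are accepted.
    return weight_dtype in FP8_DTYPES

assert weight_dtype_ok(torch.float8_e4m3fn)
assert not weight_dtype_ok(torch.int8)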

is_supported classmethod

is_supported(
    compute_capability: int | None = None,
) -> tuple[bool, str | None]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/xpu.py
@classmethod
def is_supported(
    cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
    if not current_platform.is_xpu():
        return False, "XPUFP8ScaledMM only support on XPU"
    return True, None