
Commit 0bf3f21

Revert "Add mrope op fusion (#3509)" (#3562)
This reverts commit 646c1db.

### What this PR does / why we need it?
Reverts "Add mrope op fusion (#3509)": the new fused op may lead to accuracy problems.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
1 parent 068ed70 commit 0bf3f21

File tree

3 files changed: 4 additions (+), 123 deletions (−)

tests/ut/ops/test_rotary_embedding.py
vllm_ascend/ops/rotary_embedding.py
vllm_ascend/utils.py

tests/ut/ops/test_rotary_embedding.py

Lines changed: 1 addition & 85 deletions
```diff
@@ -6,14 +6,13 @@
 from transformers.configuration_utils import PretrainedConfig
 from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
 
 MODEL = "Qwen3-0.6B"
-MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
 MAX_NUM_BATCHED_TOKEND = 10000
 
 
@@ -377,86 +376,3 @@ def test_yarn_get_mscale(self, mock_npuplatform):
             expected,
             places=6,
             msg=f"Failed for scale={scale}, mscale={mscale}")
-
-
-class TestAscendMRotaryEmbedding(unittest.TestCase):
-
-    def setUp(self):
-        # Common setup for tests
-        self.number_tokens = 3
-        self.num_head = 8
-        self.num_kvhead = 8
-        self.head_size = 128
-        self.max_position_embeddings = 128000
-        self.is_neox_style = True
-        self.rope_theta = 1000000.0
-        self.positions_1d = torch.tensor([1, 2, 3])
-        self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))
-
-        self.query = torch.randn(
-            (self.number_tokens, self.num_head * self.head_size),
-            dtype=torch.bfloat16)
-        self.key = torch.randn(
-            (self.number_tokens, self.num_kvhead * self.head_size),
-            dtype=torch.bfloat16)
-
-        # Qwen2.5-VL mrope section case
-        self.mrope_section = [16, 24, 24]
-
-        self.layer = MRotaryEmbedding(self.head_size,
-                                      self.head_size,
-                                      self.max_position_embeddings,
-                                      base=self.rope_theta,
-                                      is_neox_style=self.is_neox_style,
-                                      dtype=torch.bfloat16,
-                                      mrope_section=self.mrope_section)
-
-        self.mock_config = MagicMock()
-        self.mock_config.torchair_graph_config.enabled = False
-
-    def _create_vllm_config(self):
-        vllm_config = VllmConfig()
-        model_config = ModelConfig(MODEL_VL,
-                                   tokenizer=MODEL_VL,
-                                   max_model_len=MAX_NUM_BATCHED_TOKEND)
-        model_config.hf_config = PretrainedConfig()
-        vllm_config.model_config = model_config
-        return vllm_config
-
-    @patch('torch_npu.npu_mrope')
-    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
-    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
-    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
-    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
-    def test_forward_oot_1d_positions(self, mock_npu_mrope):
-        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
-                                       torch.zeros_like(self.key))
-
-        vllm_config = self._create_vllm_config()
-        with set_ascend_forward_context(None, vllm_config):
-            result_q, result_k = self.layer.forward_oot(
-                self.positions_1d, self.query, self.key)
-
-        mock_npu_mrope.assert_called_once()
-        self.assertFalse(torch.isnan(result_q).any().item())
-        self.assertFalse(torch.isnan(result_k).any().item())
-        self.assertEqual(result_q.shape, self.query.shape)
-
-    @patch('torch_npu.npu_mrope')
-    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
-    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
-    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
-    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
-    def test_forward_oot_2d_positions(self, mock_npu_mrope):
-        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
-                                       torch.zeros_like(self.key))
-
-        vllm_config = self._create_vllm_config()
-        with set_ascend_forward_context(None, vllm_config):
-            result_q, result_k = self.layer.forward_oot(
-                self.positions_2d, self.query, self.key)
-
-        mock_npu_mrope.assert_called_once()
-        self.assertFalse(torch.isnan(result_q).any().item())
-        self.assertFalse(torch.isnan(result_k).any().item())
-        self.assertEqual(result_q.shape, self.query.shape)
```
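Note that the removed tests never executed the real Ascend kernel: `torch_npu.npu_mrope` was stubbed out, so they only verified dispatch and output shape. Below is a minimal, self-contained sketch of the same stubbing pattern; the fake `torch_npu` module is a stand-in so the example runs on machines without Ascend hardware, and only `torch` is required.

```python
import sys
import types
import unittest
from unittest.mock import patch

import torch

try:
    import torch_npu  # noqa: F401  (present on real Ascend setups)
except ImportError:
    # Stand-in module so patch('torch_npu.npu_mrope') resolves without
    # the real Ascend extension being installed.
    fake = types.ModuleType("torch_npu")
    fake.npu_mrope = None  # placeholder attribute for patch() to replace
    sys.modules["torch_npu"] = fake


class StubbedMropeTest(unittest.TestCase):

    @patch("torch_npu.npu_mrope")
    def test_kernel_dispatch_and_shapes(self, mock_npu_mrope):
        query = torch.randn(3, 8 * 128)
        key = torch.randn(3, 8 * 128)
        # The stub returns correctly shaped tensors instead of real rotations.
        mock_npu_mrope.return_value = (torch.zeros_like(query),
                                       torch.zeros_like(key))

        import torch_npu
        out_q, out_k = torch_npu.npu_mrope(torch.tensor([1, 2, 3]), query, key)

        mock_npu_mrope.assert_called_once()
        self.assertEqual(out_q.shape, query.shape)
        self.assertFalse(torch.isnan(out_q).any().item())
        self.assertFalse(torch.isnan(out_k).any().item())


if __name__ == "__main__":
    unittest.main()
```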

vllm_ascend/ops/rotary_embedding.py

Lines changed: 1 addition & 35 deletions
```diff
@@ -22,7 +22,7 @@
 import torch_npu
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
@@ -395,37 +395,3 @@ def forward(self,
         q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                        is_neox_style, offsets)
         return q_pe, k_pe
-
-
-class AscendMRotaryEmbedding(MRotaryEmbedding):
-
-    def forward_oot(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-    ):
-        if self.mrope_section != [16, 24, 24]:
-            return super().forward_oot(positions, query, key)
-
-        import torch_npu
-        mrope_section = [0, 0, 0
-                         ] if positions.ndim == 1 else self.mrope_section
-
-        if self.cos_sin_cache.device != query.device:  # type: ignore
-            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
-                query.device)  # type: ignore
-
-        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
-            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
-                query.dtype)  # type: ignore
-
-        query, key = torch_npu.npu_mrope(positions,
-                                         query.contiguous(),
-                                         key.contiguous(),
-                                         self.cos_sin_cache.contiguous(),
-                                         self.head_size,
-                                         mrope_section=mrope_section,
-                                         rotary_mode='half')
-
-        return query, key
```
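With `AscendMRotaryEmbedding` gone, Qwen2.5-VL falls back to vLLM's stock `MRotaryEmbedding`. A simplified sketch of what that eager path computes, condensed from upstream vLLM's section handling (illustrative, not the exact upstream code):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style "half" rotation: negate and swap the two halves of the dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def mrope_apply(positions: torch.Tensor,  # (3, T): temporal/height/width ids
                query: torch.Tensor,      # (T, num_heads * head_size)
                cos_cache: torch.Tensor,  # (max_pos, head_size // 2)
                sin_cache: torch.Tensor,
                mrope_section: list[int],  # e.g. [16, 24, 24]
                head_size: int) -> torch.Tensor:
    cos = cos_cache[positions]  # (3, T, head_size // 2)
    sin = sin_cache[positions]
    # Take dims [0:16) from the temporal axis, [16:40) from height and
    # [40:64) from width, then stitch the sections back together.
    cos = torch.cat([m[i] for i, m in
                     enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
    sin = torch.cat([m[i] for i, m in
                     enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
    cos = cos.repeat(1, 2)  # (T, head_size): duplicate for both halves
    sin = sin.repeat(1, 2)
    q = query.view(-1, query.shape[-1] // head_size, head_size)
    q = q * cos.unsqueeze(-2) + rotate_half(q) * sin.unsqueeze(-2)
    return q.flatten(-2)
```

The removed `torch_npu.npu_mrope` call collapsed this gather/split/rotate sequence into a single kernel, passing `mrope_section=[0, 0, 0]` for plain 1-D positions.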

vllm_ascend/utils.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -517,8 +517,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         AscendReplicatedLinear,
         AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
-        AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
+        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
+        AscendYaRNRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
         AscendLogitsProcessor, AscendParallelLMHead,
         AscendVocabParallelEmbedding)
@@ -528,7 +528,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "QuickGELU": AscendQuickGELU,
         "SiluAndMul": AscendSiluAndMul,
         "RotaryEmbedding": AscendRotaryEmbedding,
-        "MRotaryEmbedding": AscendMRotaryEmbedding,
         "ColumnParallelLinear": AscendColumnParallelLinear,
         "RowParallelLinear": AscendRowParallelLinear,
         "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
```
