[FIXBUG ] Allow disabling rocm_aiter_fa backend for ROCm GPUs not compatible with AITER #22795
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Code Review
This pull request correctly fixes a startup failure on incompatible ROCm GPUs by making the import of rocm_aiter_fa conditional based on an environment variable. The approach is sound. My review includes one high-severity suggestion to improve performance by caching the environment variable lookup, as it currently resides in a hot path.
        
          
vllm/v1/spec_decode/eagle.py (Outdated)
          
        
```python
if os.environ.get("VLLM_ROCM_USE_AITER") == "1":
    from vllm.v1.attention.backends.rocm_aiter_fa import (
        AiterFlashAttentionMetadata)
    allowed_types += (AiterFlashAttentionMetadata, )
```
Calling os.environ.get() inside the propose method can introduce performance overhead, as this method is on a hot path during inference. It's better to check the environment variable only once, when the module is imported.
I recommend defining a module-level constant at the top of the file:

```python
# At the top of vllm/v1/spec_decode/eagle.py
import os

_VLLM_ROCM_USE_AITER = os.environ.get("VLLM_ROCM_USE_AITER") == "1"
```

Then, you can use this constant here:

```python
if _VLLM_ROCM_USE_AITER:
    from vllm.v1.attention.backends.rocm_aiter_fa import (
        AiterFlashAttentionMetadata)
    allowed_types += (AiterFlashAttentionMetadata, )
```

This change will improve performance by avoiding repeated environment variable lookups.
Agree
Hi @russellb, would you be so kind as to review this PR? Right now you can't start vLLM with ROCm on RDNA3 GPUs like the 7900 XTX.
        
          
vllm/v1/spec_decode/eagle.py (Outdated)
          
        
```python
if os.environ.get("VLLM_ROCM_USE_AITER") == "1":
    from vllm.v1.attention.backends.rocm_aiter_fa import (
        AiterFlashAttentionMetadata)
    allowed_types += (AiterFlashAttentionMetadata, )
```
See the pre-commit failures under this line
        
          
vllm/v1/spec_decode/eagle.py (Outdated)
          
        
```python
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
 FlashAttentionMetadata))
allowed_types = (TritonAttentionMetadata, FlashAttentionMetadata)
if os.environ.get("VLLM_ROCM_USE_AITER") == "1":
```
Is there any way you can make this more dynamic, if it's known which device types support this and which don't?
@russellb
I think the architecture names can be used, but the list will always have to be expanded. Do you know of another mechanism for this?
For example:
```python
import torch


def _is_rocm_gpu_with_matrix_cores() -> bool:
    if not torch.cuda.is_available() or not torch.version.hip:
        return False
    try:
        device_properties = torch.cuda.get_device_properties(
            torch.cuda.current_device())
        gcn_arch_name = getattr(device_properties, "gcnArchName", "")
        supported_archs = ("gfx908", "gfx90a", "gfx940", "gfx941", "gfx942")
        return any(gcn_arch_name.startswith(arch) for arch in supported_archs)
    except (RuntimeError, AttributeError):
        return False
```
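For illustration, such a helper could then gate the optional import. This is a sketch only: the placement is hypothetical, and only the import path is taken from this PR's diff.

```python
# Sketch only: gate the AITER import on the architecture check above.
# The import path comes from this PR's diff; the placement is hypothetical.
if _is_rocm_gpu_with_matrix_cores():
    from vllm.v1.attention.backends.rocm_aiter_fa import (
        AiterFlashAttentionMetadata)
```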
@JartX
Let's cache the value of os.environ.get, as its overhead is large, similar to
#17067
An alternative approach is to check whether aiter is installed using from importlib.util import find_spec. However, this is also a costly operation, so it should only be called once, when a class is initialized or a module is imported.
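As a rough, stdlib-only illustration of that overhead (this is not a benchmark from the PR, and the numbers will vary by machine), one can compare a repeated lookup against a value cached once at import time:

```python
# Rough illustration: repeated os.environ.get() lookups on a hot path
# versus a boolean cached once at module import time.
import os
import timeit

os.environ.setdefault("VLLM_ROCM_USE_AITER", "1")
_CACHED = os.environ.get("VLLM_ROCM_USE_AITER") == "1"

repeated = timeit.timeit(
    'os.environ.get("VLLM_ROCM_USE_AITER") == "1"',
    setup="import os",
    number=1_000_000,
)
cached = timeit.timeit("_CACHED", globals=globals(), number=1_000_000)

print(f"repeated lookup: {repeated:.3f}s  cached: {cached:.3f}s")
```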
625860e to 9bc9f67 (Compare)
Signed-off-by: JartX <[email protected]>
9bc9f67 to d23a403 (Compare)
Signed-off-by: JartX <[email protected]>
Signed-off-by: JartX <[email protected]>
Signed-off-by: JartX <[email protected]>
LGTM, thanks for the work!
cc @tjtanaa
@JartX Maybe let's do this instead: we store the allowed_types in the EagleProposer class. I have written a simple script to time the overhead, and it seems quite high, since this cost is incurred on every decode step. Usually we are decoding a few thousand tokens, as in thinking mode, so the cost is multiplied thousand-fold per request.

Proposed solution:

```python
from importlib.util import find_spec


class EagleProposer:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        ...
        self.allowed_attn_types = ()
        if current_platform.is_rocm():
            self.allowed_attn_types += (TritonAttentionMetadata,
                                        FlashAttentionMetadata)
            if find_spec("aiter"):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata)
                self.allowed_attn_types += (AiterFlashAttentionMetadata, )
        else:
            self.allowed_attn_types = (FlashAttentionMetadata,
                                       TreeAttentionMetadata)
        ...

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ) -> torch.Tensor:
        ...
        assert isinstance(attn_metadata, self.allowed_attn_types)
        ...
```
    
          
```python
...
from vllm.platforms.rocm import on_mi3xx
...
if current_platform.is_rocm() and find_spec("aiter") and on_mi3xx:
    ...
```

Then we can revert all the changes from eagle.py. This also handles the case where the ...
Hi @tjtanaa, bad news: it crashes at another point after applying the last recommendation. I would say this error comes from somewhere else. Are you sure we can't choose either of the two solutions verified above?
@tjtanaa I think this is now the better way: everything goes smoothly and works like a charm.
```python
if self.use_cuda_graph and \
        batch_size <= self.cudagraph_batch_sizes[-1]:
```
Nits: can you revert all of the unrelated changes?
Hi @tjtanaa!
These changes were included so I could pass pre-commit. I've only been contributing to the project for a short time, and @mgoin told me that pre-commit should normally be used:
https://marketplace.visualstudio.com/items?itemName=elagil.pre-commit-helper
https://github.com/pre-commit/pre-commit
https://github.com/vllm-project/vllm/blob/main/.github/workflows/pre-commit.yml
so that it would correctly format the file after the changes.
Sorry if this bothered you, and thank you very much for your time and dedication. If you find that I have it configured incorrectly, please don't hesitate to let me know.
P.S.: If I remove the spaces and run the pre-commit check again, I get an error, so I have to use the fix; it then adds the spaces back and everything passes.
@JartX Can you revert all of the unrelated changes, i.e. the changes in indentation and spaces?
Signed-off-by: tjtanaa <[email protected]>
        
          
vllm/v1/spec_decode/eagle.py (Outdated)
          
        
```python
assert isinstance(attn_metadata, FlashAttentionMetadata)
# The mypy errors are caused because mypy cannot infer the type of
# attn_metadata. We add this assert to help mypy.
assert isinstance(attn_metadata, FlashAttentionMetadata)
```
@JartX I tested using another backend. This will cause issues, because FlashAttentionMetadata is not a generic base class:
TreeAttentionMetadata, AiterFlashAttentionMetadata, TritonAttentionMetadata, and FlashAttentionMetadata are four distinct classes.
I have opened a PR into your branch, JartX#1. It is a mypy fix through a Protocol class.
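A rough sketch of what a Protocol-based fix of this kind can look like (the class and attribute names below are placeholders, not taken from JartX#1):

```python
# Hypothetical sketch of a Protocol-style mypy fix; names are placeholders.
from typing import Protocol

import torch


class SupportsEagleMetadata(Protocol):
    # Only the fields that the calling code actually reads need to be listed.
    query_start_loc: torch.Tensor
    max_query_len: int


def example(attn_metadata: SupportsEagleMetadata) -> int:
    # mypy accepts any backend metadata class that structurally matches,
    # without requiring a shared base class.
    return attn_metadata.max_query_len
```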
Merged. @tjtanaa Thank you very much for helping me with the development and testing. I have very limited hardware and am still getting up to speed with working on vLLM.
Thank you for your work on this PR as well @JartX 🥂
…aiter [Bugfix] Fix mypy error with Protocol
LGTM if tests pass, thanks!
…patible with AITER (vllm-project#22795) Signed-off-by: JartX <[email protected]> Signed-off-by: tjtanaa <[email protected]> Co-authored-by: tjtanaa <[email protected]> Signed-off-by: Duncan Moss <[email protected]>
…patible with AITER (vllm-project#22795) Signed-off-by: JartX <[email protected]> Signed-off-by: tjtanaa <[email protected]> Co-authored-by: tjtanaa <[email protected]>
…patible with AITER (vllm-project#22795) Signed-off-by: JartX <[email protected]> Signed-off-by: tjtanaa <[email protected]> Co-authored-by: tjtanaa <[email protected]> Signed-off-by: Xiao Yu <[email protected]>
This PR fixes an issue where vLLM failed to start on ROCm GPUs that are not compatible with the rocm_aiter_fa attention backend. An example of such a GPU is the AMD Radeon RX 7900 XTX, which uses the RDNA 3 architecture.
The bug was introduced in commit 1ee5ead, which hardcoded the import of the vllm.v1.attention.backends.rocm_aiter_fa module in vllm/v1/spec_decode/eagle.py. This caused vLLM to fail on startup before it could even select a different attention backend.
To solve this, I've added a conditional check that allows the user to explicitly enable this backend. The rocm_aiter_fa module will now only be loaded if the environment variable VLLM_ROCM_USE_AITER is set to 1.
This change ensures that:
- Users with ROCm GPUs that are not compatible with the rocm_aiter_fa backend can use vLLM without any startup failures.
- Users who do need this backend can still enable it manually, preserving the original functionality.
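For reference, a minimal sketch of the gating behaviour described above (it mirrors the diff shown earlier in the review rather than the final merged code, and the surrounding eagle.py context is omitted):

```python
# Sketch of the opt-in gate: the AITER backend is only imported when the
# user explicitly requests it, so unsupported GPUs (e.g. RDNA3 cards such
# as the RX 7900 XTX) never touch this module at startup.
import os

if os.environ.get("VLLM_ROCM_USE_AITER") == "1":
    from vllm.v1.attention.backends.rocm_aiter_fa import (
        AiterFlashAttentionMetadata)
```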