Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/paddle/Overview_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ tensor 数学操作
" :ref:`paddle.asin <cn_api_paddle_asin>` ", "arcsine 函数"
" :ref:`paddle.atan <cn_api_paddle_atan>` ", "arctangent 函数"
" :ref:`paddle.atan2 <cn_api_paddle_atan2>` ", "arctangent2 函数"
" :ref:`paddle.baddbmm <cn_api_paddle_baddbmm>` ", "对两个批量矩阵 x 和 y 进行乘法运算,将结果乘以标量 alpha,再加上 input 与 beta 的乘积,得到输出"
" :ref:`paddle.ceil <cn_api_paddle_ceil>` ", "向上取整运算函数"
" :ref:`paddle.clip <cn_api_paddle_clip>` ", "将输入的所有元素进行剪裁,使得输出元素限制在[min, max]内"
" :ref:`paddle.conj <cn_api_paddle_conj>` ", "逐元素计算 Tensor 的共轭运算"
Expand Down
35 changes: 35 additions & 0 deletions docs/api/paddle/baddbmm_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
.. _cn_api_paddle_baddbmm:

baddbmm
-------------------------------

.. py:function:: paddle.baddbmm(input, x, y, alpha=1.0, beta=1.0, name=None)




计算 x 和 y 的批量矩阵乘积,将结果乘以标量 alpha,再加上 input 与标量 beta 的乘积,得到输出。其中 input 与 x、y 乘积的维度必须是可广播的。

计算过程的公式为:

.. math::
out = alpha * x * y + beta * input

参数
::::::::::::

- **input** (Tensor) - 输入 Tensor input,数据类型支持 bfloat16、float16、float32、float64。
- **x** (Tensor) - 输入 Tensor x,数据类型支持 bfloat16、float16、float32、float64。
- **y** (Tensor) - 输入 Tensor y,数据类型支持 bfloat16、float16、float32、float64。
- **alpha** (float,可选) - 乘以 x*y 的标量,数据类型支持 float,默认值为 1.0。
- **beta** (float,可选) - 乘以 input 的标量,数据类型支持 float,默认值为 1.0。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
::::::::::::
计算得到的 Tensor。Tensor 数据类型与输入 input 数据类型一致。

代码示例
::::::::::::

COPY-FROM: paddle.baddbmm
Original file line number Diff line number Diff line change
@@ -1,18 +1,38 @@
## [ 组合替代实现 ]torch.baddbmm
## [ torch 参数更多]torch.baddbmm

### [torch.baddbmm](https://pytorch.org/docs/stable/generated/torch.baddbmm.html?highlight=baddbmm#torch.baddbmm)

```python
torch.baddbmm(input, batch1, batch2, beta=1, alpha=1, out=None)
```
Paddle 无此 API,需要组合实现。

### [paddle.baddbmm](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/baddbmm_cn.html)

```python
paddle.baddbmm(input, x, y, alpha=1.0, beta=1.0, name=None)
```

PyTorch 相比 Paddle 支持更多其他参数,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
| ------- | ------- | ------- |
| input | input | 表示输入的 Tensor 。 |
| batch1 | x | 表示输入的第一个 Tensor ,仅参数名不一致。 |
| batch2 | y | 表示输入的第二个 Tensor ,仅参数名不一致。 |
| beta | beta | 表示乘以 input 的标量。 |
| alpha | alpha | 表示乘以 batch1 * batch2 的标量。 |
| out | - | 表示输出的 Tensor , Paddle 无此参数,需要转写。 |

### 转写示例

#### out: 输出的 Tensor

```python
# PyTorch 写法
y = torch.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha)
torch.baddbmm(input, batch1, batch2, beta, alpha, out=output)

# Paddle 写法
y = beta * input + alpha * paddle.bmm(batch1, batch2)
paddle.assign(paddle.baddbmm(input, x, y, beta, alpha), output)
```