Skip to content

Commit 17d5773

Browse files
authored
【Hackathon 8th No.2】为 Paddle 新增 baddbmm API (#7041)
1 parent 76e31b0 commit 17d5773

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

docs/api/paddle/Overview_cn.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ tensor 数学操作
4848
" :ref:`paddle.asin <cn_api_paddle_asin>` ", "arcsine 函数"
4949
" :ref:`paddle.atan <cn_api_paddle_atan>` ", "arctangent 函数"
5050
" :ref:`paddle.atan2 <cn_api_paddle_atan2>` ", "arctangent2 函数"
51+
" :ref:`paddle.baddbmm <cn_api_paddle_baddbmm>` ", "对两个批量矩阵 x 和 y 进行乘法运算,将结果乘以标量 alpha,再加上 input 与 beta 的乘积,得到输出"
5152
" :ref:`paddle.ceil <cn_api_paddle_ceil>` ", "向上取整运算函数"
5253
" :ref:`paddle.clip <cn_api_paddle_clip>` ", "将输入的所有元素进行剪裁,使得输出元素限制在[min, max]内"
5354
" :ref:`paddle.conj <cn_api_paddle_conj>` ", "逐元素计算 Tensor 的共轭运算"

docs/api/paddle/baddbmm_cn.rst

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
.. _cn_api_paddle_baddbmm:
2+
3+
baddbmm
4+
-------------------------------
5+
6+
.. py:function:: paddle.baddbmm(input, x, y, alpha=1.0, beta=1.0, name=None)
7+
8+
9+
10+
11+
计算 x 和 y 的批量矩阵乘积,将结果乘以标量 alpha,再加上 input 与标量 beta 的乘积,得到输出。其中 input 与 x、y 乘积的维度必须是可广播的。
12+
13+
计算过程的公式为:
14+
15+
.. math::
16+
out = alpha * x * y + beta * input
17+
18+
参数
19+
::::::::::::
20+
21+
- **input** (Tensor) - 输入 Tensor input,数据类型支持 bfloat16、float16、float32、float64。
22+
- **x** (Tensor) - 输入 Tensor x,数据类型支持 bfloat16、float16、float32、float64。
23+
- **y** (Tensor) - 输入 Tensor y,数据类型支持 bfloat16、float16、float32、float64。
24+
- **alpha** (float,可选) - 乘以 x*y 的标量,数据类型支持 float,默认值为 1.0。
25+
- **beta** (float,可选) - 乘以 input 的标量,数据类型支持 float,默认值为 1.0。
26+
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
27+
28+
返回
29+
::::::::::::
30+
计算得到的 Tensor。Tensor 数据类型与输入 input 数据类型一致。
31+
32+
代码示例
33+
::::::::::::
34+
35+
COPY-FROM: paddle.baddbmm
Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,38 @@
1-
## [ 组合替代实现 ]torch.baddbmm
1+
## [ torch 参数更多]torch.baddbmm
22

33
### [torch.baddbmm](https://pytorch.org/docs/stable/generated/torch.baddbmm.html?highlight=baddbmm#torch.baddbmm)
44

55
```python
66
torch.baddbmm(input, batch1, batch2, beta=1, alpha=1, out=None)
77
```
8-
Paddle 无此 API,需要组合实现。
8+
9+
### [paddle.baddbmm](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/baddbmm_cn.html)
10+
11+
```python
12+
paddle.baddbmm(input, x, y, alpha=1.0, beta=1.0, name=None)
13+
```
14+
15+
PyTorch 相比 Paddle 支持更多其他参数,具体如下:
16+
17+
### 参数映射
18+
19+
| PyTorch | PaddlePaddle | 备注 |
20+
| ------- | ------- | ------- |
21+
| input | input | 表示输入的 Tensor 。 |
22+
| batch1 | x | 表示输入的第一个 Tensor ,仅参数名不一致。 |
23+
| batch2 | y | 表示输入的第二个 Tensor ,仅参数名不一致。 |
24+
| beta | beta | 表示乘以 input 的标量。 |
25+
| alpha | alpha | 表示乘以 batch1 * batch2 的标量。 |
26+
| out | - | 表示输出的 Tensor , Paddle 无此参数,需要转写。 |
927

1028
### 转写示例
1129

30+
#### out: 输出的 Tensor
31+
1232
```python
1333
# PyTorch 写法
14-
y = torch.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha)
34+
torch.baddbmm(input, batch1, batch2, beta, alpha, out=output)
1535

1636
# Paddle 写法
17-
y = beta * input + alpha * paddle.bmm(batch1, batch2)
37+
paddle.assign(paddle.baddbmm(input, x, y, beta, alpha), output)
1838
```

0 commit comments

Comments
 (0)