Skip to content

Commit abb83cb

Browse files
committed
Refactor Interpolate node to centralize attribute management
Moved attribute ID definitions into `interpAttrs` structure for better maintainability and centralized management. Updated the codebase to reference `interpAttrs` where relevant, replacing duplicate constant definitions and improving code consistency.
1 parent 339405b commit abb83cb

File tree

2 files changed

+55
-74
lines changed

2 files changed

+55
-74
lines changed

src/plugins/intel_cpu/src/nodes/interpolate.cpp

Lines changed: 54 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,10 @@
77
#include "executors/x64/interpolate.hpp"
88
#include "executors/common/interpolate.hpp"
99

10-
#include <cpu/x64/xbyak/xbyak.h>
1110

12-
#include <algorithm>
1311
#include <cassert>
14-
#include <cmath>
1512
#include <common/c_types_map.hpp>
16-
#include <common/primitive_attr.hpp>
1713
#include <common/primitive_hashing_utils.hpp>
18-
#include <common/utils.hpp>
1914
#include <cpu/x64/cpu_isa_traits.hpp>
2015
#include <cstddef>
2116
#include <cstdint>
@@ -24,27 +19,22 @@
2419
#include <oneapi/dnnl/dnnl.hpp>
2520
#include <oneapi/dnnl/dnnl_common.hpp>
2621
#include <string>
27-
#include <unordered_map>
2822
#include <utility>
2923
#include <vector>
3024

3125
#include "common/cpu_memcpy.h"
32-
#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp"
3326
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
34-
#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp"
3527
#include "cpu/x64/jit_generator.hpp"
3628
#include "cpu_types.h"
3729
#include "dnnl_extension_utils.h"
3830
#include "eltwise.h"
39-
#include "emitters/plugin/x64/jit_emitter.hpp"
4031
#include "emitters/plugin/x64/jit_load_store_emitters.hpp"
4132
#include "fake_quantize.h"
4233
#include "graph_context.h"
4334
#include "memory_desc/cpu_memory_desc.h"
4435
#include "node.h"
4536
#include "nodes/common/blocked_desc_creator.h"
4637
#include "nodes/executors/executor.hpp"
47-
#include "nodes/executors/interpolate.hpp"
4838
#include "nodes/executors/interpolate_list.hpp"
4939
#include "nodes/node_config.h"
5040
#include "onednn/iml_type_mapper.h"
@@ -58,7 +48,6 @@
5848
#include "openvino/op/interpolate.hpp"
5949
#include "shape_inference/shape_inference.hpp"
6050
#include "shape_inference/shape_inference_cpu.hpp"
61-
#include "utils/bfloat16.hpp"
6251
#include "utils/general_utils.h"
6352
#include "utils/ngraph_utils.hpp"
6453
#include "utils/precision_support.h"
@@ -115,9 +104,15 @@ using ngInterpShapeCalcMode = ov::op::v4::Interpolate::ShapeCalcMode;
115104

116105
bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
117106
try {
107+
constexpr size_t DATA_ID = 0;
108+
constexpr size_t SCALES_ID = 2;
109+
constexpr size_t AXES_ID = 3;
110+
constexpr size_t SIZE_OR_SCALE_ID_V11 = 1;
111+
constexpr size_t AXES_ID_V11 = 2;
112+
118113
if (const auto interp = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
119-
const auto& interpAttr = interp->get_attrs();
120-
const auto& interpMode = interpAttr.mode;
114+
const auto& tmpInterpAttr = interp->get_attrs();
115+
const auto& interpMode = tmpInterpAttr.mode;
121116
if (!one_of(interpMode,
122117
ngInterpMode::NEAREST,
123118
ngInterpMode::LINEAR,
@@ -127,7 +122,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
127122
return false;
128123
}
129124

130-
const auto& interpCoordTransMode = interpAttr.coordinate_transformation_mode;
125+
const auto& interpCoordTransMode = tmpInterpAttr.coordinate_transformation_mode;
131126
if (!one_of(interpCoordTransMode,
132127
ngInterpCoordTransf::HALF_PIXEL,
133128
ngInterpCoordTransf::PYTORCH_HALF_PIXEL,
@@ -140,7 +135,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
140135
}
141136

142137
if (interpMode == ngInterpMode::NEAREST) {
143-
const auto& interpNearestMode = interpAttr.nearest_mode;
138+
const auto& interpNearestMode = tmpInterpAttr.nearest_mode;
144139
if (!one_of(interpNearestMode,
145140
ngInterpNearMode::ROUND_PREFER_FLOOR,
146141
ngInterpNearMode::ROUND_PREFER_CEIL,
@@ -153,7 +148,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
153148
}
154149
}
155150

156-
const auto& interpShapeCalcMode = interpAttr.shape_calculation_mode;
151+
const auto& interpShapeCalcMode = tmpInterpAttr.shape_calculation_mode;
157152
if (!one_of(interpShapeCalcMode, ngInterpShapeCalcMode::SCALES, ngInterpShapeCalcMode::SIZES)) {
158153
errorMessage =
159154
"Interpolate-4 does not support shape_calculation_mode: " + ov::as_string(interpShapeCalcMode);
@@ -183,20 +178,20 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
183178
errorMessage = "Only const 'axes' input is supported in Interpolate-4";
184179
return false;
185180
}
186-
} else if (const auto interp = ov::as_type_ptr<const ov::op::v11::Interpolate>(op)) {
187-
const auto& interpAttr = interp->get_attrs();
188-
const auto& interpMode = interpAttr.mode;
181+
} else if (const auto interp_v11 = ov::as_type_ptr<const ov::op::v11::Interpolate>(op)) {
182+
const auto& tmpInterpAttr = interp_v11->get_attrs();
183+
const auto& interpMode = tmpInterpAttr.mode;
189184
if (!one_of(interpMode, ngInterpMode::BILINEAR_PILLOW, ngInterpMode::BICUBIC_PILLOW)) {
190185
errorMessage = "Interpolate-11 does not support interpolate mode: " + ov::as_string(interpMode);
191186
return false;
192187
}
193-
const auto& interpShapeCalcMode = interpAttr.shape_calculation_mode;
188+
const auto& interpShapeCalcMode = tmpInterpAttr.shape_calculation_mode;
194189
if (!one_of(interpShapeCalcMode, ngInterpShapeCalcMode::SCALES, ngInterpShapeCalcMode::SIZES)) {
195190
errorMessage =
196191
"Interpolate-11 does not support shape_calculation_mode: " + ov::as_string(interpShapeCalcMode);
197192
return false;
198193
}
199-
const size_t dataRank = interp->get_input_partial_shape(DATA_ID).rank().get_length();
194+
const size_t dataRank = interp_v11->get_input_partial_shape(DATA_ID).rank().get_length();
200195
if (dataRank < 2 || dataRank > 4) {
201196
// pillow only resize on H and W. resize on D(depth) is not defined.
202197
errorMessage = "Interpolate-11 does not support input tensor of rank : " + std::to_string(dataRank);
@@ -207,8 +202,8 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
207202
errorMessage = "Only const 'scales_or_sizes' input is supported for static shapes in Interpolate-11";
208203
return false;
209204
}
210-
if (interp->get_input_size() > 2 && ov::as_type_ptr<const ov::op::v0::Constant>(
211-
interp->get_input_node_shared_ptr(AXES_ID_V11)) == nullptr) {
205+
if (interp_v11->get_input_size() > 2 && ov::as_type_ptr<const ov::op::v0::Constant>(
206+
interp_v11->get_input_node_shared_ptr(AXES_ID_V11)) == nullptr) {
212207
errorMessage = "Only const 'axes' input is supported in Interpolate-11";
213208
return false;
214209
}
@@ -257,8 +252,8 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
257252
: Node(op, context, InterpolateShapeInferFactory(op)) {
258253
std::string errorMessage;
259254
if (isSupportedOperation(op, errorMessage)) {
260-
dataRank = getInputShapeAtPort(DATA_ID).getRank();
261-
if (const auto interp = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
255+
dataRank = getInputShapeAtPort(interpAttrs.DATA_ID).getRank();
256+
if (const auto interp_v4 = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
262257
is_version11 = false;
263258
const auto numInputs = inputShapes.size();
264259
if (numInputs != 3 && numInputs != 4) {
@@ -269,7 +264,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
269264
}
270265
isAxesSpecified = numInputs != 3;
271266

272-
const auto& interpAttr = interp->get_attrs();
267+
const auto& interpAttr = interp_v4->get_attrs();
273268

274269
const auto& interpMode = interpAttr.mode;
275270
if (interpMode == ngInterpMode::NEAREST) {
@@ -351,14 +346,14 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
351346
}
352347

353348
const auto scalesNode =
354-
ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr(SCALES_ID));
349+
ov::as_type_ptr<const ov::op::v0::Constant>(interp_v4->get_input_node_shared_ptr(interpAttrs.SCALES_ID));
355350
if (scalesNode) {
356351
scales = scalesNode->cast_vector<float>();
357352
isScaleConstant = true;
358353
}
359354

360355
if (isAxesSpecified) {
361-
axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr(AXES_ID))
356+
axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp_v4->get_input_node_shared_ptr(interpAttrs.AXES_ID))
362357
->cast_vector<int>();
363358
} else {
364359
axes.resize(dataRank);
@@ -396,7 +391,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
396391
if (interpShapeCalcMode == ngInterpShapeCalcMode::SCALES) {
397392
interpAttrs.shapeCalcMode = InterpolateShapeCalcMode::scales;
398393
const auto scalesNode = ov::as_type_ptr<const ov::op::v0::Constant>(
399-
interp->get_input_node_shared_ptr(SIZE_OR_SCALE_ID_V11));
394+
interp->get_input_node_shared_ptr(interpAttrs.SIZE_OR_SCALE_ID_V11));
400395
if (scalesNode) {
401396
scales = scalesNode->cast_vector<float>();
402397
isScaleConstant = true;
@@ -426,7 +421,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
426421
}
427422

428423
if (isAxesSpecified) {
429-
axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr(AXES_ID_V11))
424+
axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr(interpAttrs.AXES_ID_V11))
430425
->cast_vector<int>();
431426
if (dataRank == 4 && axes.size() == 2 && axes[0] == 1 && axes[1] == 2) {
432427
interpAttrs.NCHWAsNHWC = true;
@@ -496,7 +491,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
496491
return;
497492
}
498493

499-
ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort(DATA_ID);
494+
ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort(interpAttrs.DATA_ID);
500495

501496
#if defined(OV_CPU_WITH_ACL)
502497
bool isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::f16);
@@ -519,7 +514,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
519514
ov::element::Type outputPrecision = inputPrecision;
520515

521516
if (!fusedWith.empty()) {
522-
outputPrecision = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(DATA_ID);
517+
outputPrecision = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(interpAttrs.DATA_ID);
523518
}
524519

525520
#if !defined(OV_CPU_WITH_ACL)
@@ -550,29 +545,29 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
550545
auto& creatorsMap = BlockedDescCreator::getCommonCreators();
551546
auto pushDesc = [&](LayoutType dataFormat,
552547
impl_desc_type implDetail,
553-
bool is_version11,
548+
bool is_version11_desc,
554549
bool useAclExecutor = false) {
555-
config.inConfs[DATA_ID].setMemDesc(
556-
creatorsMap.at(dataFormat)->createSharedDesc(inputPrecision, getInputShapeAtPort(DATA_ID)));
557-
if (is_version11) {
550+
config.inConfs[interpAttrs.DATA_ID].setMemDesc(
551+
creatorsMap.at(dataFormat)->createSharedDesc(inputPrecision, getInputShapeAtPort(interpAttrs.DATA_ID)));
552+
if (is_version11_desc) {
558553
if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes) {
559-
config.inConfs[SIZE_OR_SCALE_ID_V11].setMemDesc(
554+
config.inConfs[interpAttrs.SIZE_OR_SCALE_ID_V11].setMemDesc(
560555
creatorsMap.at(LayoutType::ncsp)
561-
->createSharedDesc(targetShapeType, getInputShapeAtPort(SIZE_OR_SCALE_ID_V11)));
556+
->createSharedDesc(targetShapeType, getInputShapeAtPort(interpAttrs.SIZE_OR_SCALE_ID_V11)));
562557
} else {
563-
config.inConfs[SIZE_OR_SCALE_ID_V11].setMemDesc(
558+
config.inConfs[interpAttrs.SIZE_OR_SCALE_ID_V11].setMemDesc(
564559
creatorsMap.at(LayoutType::ncsp)
565-
->createSharedDesc(scalesType, getInputShapeAtPort(SIZE_OR_SCALE_ID_V11)));
560+
->createSharedDesc(scalesType, getInputShapeAtPort(interpAttrs.SIZE_OR_SCALE_ID_V11)));
566561
}
567562

568563
if (isAxesSpecified) {
569-
config.inConfs[AXES_ID_V11].setMemDesc(
570-
creatorsMap.at(LayoutType::ncsp)->createSharedDesc(axesType, getInputShapeAtPort(AXES_ID_V11)));
564+
config.inConfs[interpAttrs.AXES_ID_V11].setMemDesc(
565+
creatorsMap.at(LayoutType::ncsp)->createSharedDesc(axesType, getInputShapeAtPort(interpAttrs.AXES_ID_V11)));
571566
}
572567
} else {
573-
config.inConfs[TARGET_SHAPE_ID].setMemDesc(
568+
config.inConfs[interpAttrs.TARGET_SHAPE_ID].setMemDesc(
574569
creatorsMap.at(LayoutType::ncsp)
575-
->createSharedDesc(targetShapeType, getInputShapeAtPort(TARGET_SHAPE_ID)));
570+
->createSharedDesc(targetShapeType, getInputShapeAtPort(interpAttrs.TARGET_SHAPE_ID)));
576571
config.inConfs[get_scale_id()].setMemDesc(
577572
creatorsMap.at(LayoutType::ncsp)->createSharedDesc(scalesType, getInputShapeAtPort(get_scale_id())));
578573

@@ -644,7 +639,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
644639
}
645640
pushDesc(LayoutType::ncsp, ref, true);
646641
} else {
647-
const auto& dataMinDims = getInputShapeAtPort(DATA_ID).getMinDims();
642+
const auto& dataMinDims = getInputShapeAtPort(interpAttrs.DATA_ID).getMinDims();
648643
bool isBlkApplied = dataRank > 1 && dataMinDims[1] != Shape::UNDEFINED_DIM && dataMinDims[1] > 1;
649644

650645
#if defined(OV_CPU_WITH_ACL)
@@ -703,17 +698,17 @@ bool Interpolate::needShapeInfer() const {
703698
if (lastScales.empty()) {
704699
return true;
705700
}
706-
const auto* scales = getSrcDataAtPortAs<const float>(get_scale_id());
701+
const auto* scales_inf = getSrcDataAtPortAs<const float>(get_scale_id());
707702
for (size_t i = 0; i < lastScales.size(); i++) {
708-
if (lastScales[i] != scales[i]) {
703+
if (lastScales[i] != scales_inf[i]) {
709704
return true;
710705
}
711706
}
712707
} else {
713708
if (lastSizes.empty()) {
714709
return true;
715710
}
716-
const auto* sizes = getSrcDataAtPortAs<const int32_t>(TARGET_SHAPE_ID);
711+
const auto* sizes = getSrcDataAtPortAs<const int32_t>(interpAttrs.TARGET_SHAPE_ID);
717712
for (size_t i = 0; i < lastSizes.size(); i++) {
718713
if (sizes[i] != lastSizes[i]) {
719714
return true;
@@ -726,11 +721,11 @@ bool Interpolate::needShapeInfer() const {
726721
void Interpolate::executeDynamicImpl(const dnnl::stream& strm) {
727722
execute(strm);
728723

729-
const size_t port = interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes ? TARGET_SHAPE_ID : get_scale_id();
724+
const size_t port = interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes ? interpAttrs.TARGET_SHAPE_ID : get_scale_id();
730725
const auto& memory = getParentEdgeAt(port)->getMemory();
731726
if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::scales) {
732-
const auto* scales = memory.getDataAs<const float>();
733-
lastScales.assign(scales, scales + memory.getDesc().getShape().getElementsCount());
727+
const auto* scales_dyn = memory.getDataAs<const float>();
728+
lastScales.assign(scales_dyn, scales_dyn + memory.getDesc().getShape().getElementsCount());
734729
} else {
735730
const auto* sizes = memory.getDataAs<const int32_t>();
736731
lastSizes.assign(sizes, sizes + memory.getDesc().getShape().getElementsCount());
@@ -743,15 +738,15 @@ bool Interpolate::needPrepareParams() const {
743738

744739
inline int Interpolate::get_scale_id() const {
745740
if (is_version11) {
746-
return SIZE_OR_SCALE_ID_V11;
741+
return interpAttrs.SIZE_OR_SCALE_ID_V11;
747742
}
748-
return SCALES_ID;
743+
return interpAttrs.SCALES_ID;
749744
}
750745
inline int Interpolate::get_axis_id() const {
751746
if (is_version11) {
752-
return AXES_ID_V11;
747+
return interpAttrs.AXES_ID_V11;
753748
}
754-
return AXES_ID;
749+
return interpAttrs.AXES_ID;
755750
}
756751

757752
void Interpolate::prepareParams() {
@@ -764,13 +759,13 @@ void Interpolate::prepareParams() {
764759
THROW_CPU_NODE_ERR("has undefined destination memory");
765760
}
766761

767-
auto srcMemPtr = getSrcMemoryAtPort(DATA_ID);
762+
auto srcMemPtr = getSrcMemoryAtPort(interpAttrs.DATA_ID);
768763
if (!srcMemPtr || !srcMemPtr->isDefined()) {
769764
THROW_CPU_NODE_ERR("has undefined input memory");
770765
}
771766

772767
if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes) {
773-
auto tsMemPtr = getSrcMemoryAtPort(TARGET_SHAPE_ID);
768+
auto tsMemPtr = getSrcMemoryAtPort(interpAttrs.TARGET_SHAPE_ID);
774769
if (!tsMemPtr || !tsMemPtr->isDefined()) {
775770
THROW_CPU_NODE_ERR("has undefined target shape memory");
776771
}
@@ -884,7 +879,7 @@ void Interpolate::prepareParams() {
884879
}
885880

886881
void Interpolate::createPrimitive() {
887-
auto srcMemPtr = getSrcMemoryAtPort(DATA_ID);
882+
auto srcMemPtr = getSrcMemoryAtPort(interpAttrs.DATA_ID);
888883
auto dstMemPtr = getDstMemoryAtPort(0);
889884
if (!srcMemPtr) {
890885
THROW_CPU_NODE_ERR("has null input memory");
@@ -978,7 +973,7 @@ std::vector<float> Interpolate::getScales(const VectorDims& srcDimPad, const Vec
978973

979974
void Interpolate::execute([[maybe_unused]] const dnnl::stream& strm) {
980975
auto dstMemPtr = getDstMemoryAtPort(0);
981-
auto srcMemPtr = getSrcMemoryAtPort(DATA_ID);
976+
auto srcMemPtr = getSrcMemoryAtPort(interpAttrs.DATA_ID);
982977

983978
if (execPtr) {
984979
auto* dst_data = dstMemPtr->getDataAs<uint8_t>();

src/plugins/intel_cpu/src/nodes/interpolate.h

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#pragma once
66

77
#include <cassert>
8-
#include <common/primitive_attr.hpp>
98
#include <cstddef>
109
#include <cstdint>
1110
#include <memory>
@@ -16,27 +15,14 @@
1615

1716
#include "cpu_types.h"
1817
#include "executors/interpolate.hpp"
18+
#include "executors/interpolate_config.hpp"
1919
#include "graph_context.h"
2020
#include "node.h"
2121
#include "openvino/core/node.hpp"
22-
#include "openvino/core/type/element_type.hpp"
23-
24-
#define MAX_INPUT_INTERPOLATE 8
2522

2623
namespace ov::intel_cpu::node {
2724

2825
class Interpolate : public Node {
29-
public:
30-
static constexpr size_t DATA_ID = 0;
31-
static constexpr size_t TARGET_SHAPE_ID = 1;
32-
static constexpr size_t SCALES_ID = 2;
33-
static constexpr size_t AXES_ID = 3;
34-
static constexpr size_t SIZE_OR_SCALE_ID_V11 = 1;
35-
static constexpr size_t AXES_ID_V11 = 2;
36-
static constexpr int CUBIC_GRID_LEN = 4;
37-
static constexpr float PILLOW_BILINEAR_WINDOW_SCALE = 1.0f;
38-
static constexpr float PILLOW_BICUBIC_WINDOW_SCALE = 2.0f;
39-
4026
public:
4127
Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& context);
4228

0 commit comments

Comments
 (0)