Skip to content

Commit 6d9c745

Browse files
committed
Refactor Interpolate node to centralize attribute management
Moved attribute ID definitions into `interpAttrs` structure for better maintainability and centralized management. Updated the codebase to reference `interpAttrs` where relevant, replacing duplicate constant definitions and improving code consistency.
1 parent 339405b commit 6d9c745

File tree

2 files changed

+59
-102
lines changed

2 files changed

+59
-102
lines changed

src/plugins/intel_cpu/src/nodes/interpolate.cpp

Lines changed: 58 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -7,44 +7,25 @@
77
#include "executors/x64/interpolate.hpp"
88
#include "executors/common/interpolate.hpp"
99

10-
#include <cpu/x64/xbyak/xbyak.h>
11-
12-
#include <algorithm>
13-
#include <cassert>
14-
#include <cmath>
15-
#include <common/c_types_map.hpp>
16-
#include <common/primitive_attr.hpp>
17-
#include <common/primitive_hashing_utils.hpp>
18-
#include <common/utils.hpp>
1910
#include <cpu/x64/cpu_isa_traits.hpp>
2011
#include <cstddef>
2112
#include <cstdint>
22-
#include <cstdlib>
2313
#include <memory>
2414
#include <oneapi/dnnl/dnnl.hpp>
2515
#include <oneapi/dnnl/dnnl_common.hpp>
2616
#include <string>
27-
#include <unordered_map>
2817
#include <utility>
2918
#include <vector>
3019

3120
#include "common/cpu_memcpy.h"
32-
#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp"
33-
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
34-
#include "cpu/x64/injectors/jit_uni_quantization_injector.hpp"
35-
#include "cpu/x64/jit_generator.hpp"
3621
#include "cpu_types.h"
37-
#include "dnnl_extension_utils.h"
3822
#include "eltwise.h"
39-
#include "emitters/plugin/x64/jit_emitter.hpp"
40-
#include "emitters/plugin/x64/jit_load_store_emitters.hpp"
4123
#include "fake_quantize.h"
4224
#include "graph_context.h"
4325
#include "memory_desc/cpu_memory_desc.h"
4426
#include "node.h"
4527
#include "nodes/common/blocked_desc_creator.h"
4628
#include "nodes/executors/executor.hpp"
47-
#include "nodes/executors/interpolate.hpp"
4829
#include "nodes/executors/interpolate_list.hpp"
4930
#include "nodes/node_config.h"
5031
#include "onednn/iml_type_mapper.h"
@@ -58,7 +39,6 @@
5839
#include "openvino/op/interpolate.hpp"
5940
#include "shape_inference/shape_inference.hpp"
6041
#include "shape_inference/shape_inference_cpu.hpp"
61-
#include "utils/bfloat16.hpp"
6242
#include "utils/general_utils.h"
6343
#include "utils/ngraph_utils.hpp"
6444
#include "utils/precision_support.h"
@@ -115,9 +95,15 @@ using ngInterpShapeCalcMode = ov::op::v4::Interpolate::ShapeCalcMode;
11595

11696
bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
11797
try {
98+
constexpr size_t DATA_ID = 0;
99+
constexpr size_t SCALES_ID = 2;
100+
constexpr size_t AXES_ID = 3;
101+
constexpr size_t SIZE_OR_SCALE_ID_V11 = 1;
102+
constexpr size_t AXES_ID_V11 = 2;
103+
118104
if (const auto interp = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
119-
const auto& interpAttr = interp->get_attrs();
120-
const auto& interpMode = interpAttr.mode;
105+
const auto& tmpInterpAttr = interp->get_attrs();
106+
const auto& interpMode = tmpInterpAttr.mode;
121107
if (!one_of(interpMode,
122108
ngInterpMode::NEAREST,
123109
ngInterpMode::LINEAR,
@@ -127,7 +113,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
127113
return false;
128114
}
129115

130-
const auto& interpCoordTransMode = interpAttr.coordinate_transformation_mode;
116+
const auto& interpCoordTransMode = tmpInterpAttr.coordinate_transformation_mode;
131117
if (!one_of(interpCoordTransMode,
132118
ngInterpCoordTransf::HALF_PIXEL,
133119
ngInterpCoordTransf::PYTORCH_HALF_PIXEL,
@@ -140,7 +126,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
140126
}
141127

142128
if (interpMode == ngInterpMode::NEAREST) {
143-
const auto& interpNearestMode = interpAttr.nearest_mode;
129+
const auto& interpNearestMode = tmpInterpAttr.nearest_mode;
144130
if (!one_of(interpNearestMode,
145131
ngInterpNearMode::ROUND_PREFER_FLOOR,
146132
ngInterpNearMode::ROUND_PREFER_CEIL,
@@ -153,7 +139,7 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
153139
}
154140
}
155141

156-
const auto& interpShapeCalcMode = interpAttr.shape_calculation_mode;
142+
const auto& interpShapeCalcMode = tmpInterpAttr.shape_calculation_mode;
157143
if (!one_of(interpShapeCalcMode, ngInterpShapeCalcMode::SCALES, ngInterpShapeCalcMode::SIZES)) {
158144
errorMessage =
159145
"Interpolate-4 does not support shape_calculation_mode: " + ov::as_string(interpShapeCalcMode);
@@ -183,20 +169,20 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
183169
errorMessage = "Only const 'axes' input is supported in Interpolate-4";
184170
return false;
185171
}
186-
} else if (const auto interp = ov::as_type_ptr<const ov::op::v11::Interpolate>(op)) {
187-
const auto& interpAttr = interp->get_attrs();
188-
const auto& interpMode = interpAttr.mode;
172+
} else if (const auto interp_v11 = ov::as_type_ptr<const ov::op::v11::Interpolate>(op)) {
173+
const auto& tmpInterpAttr = interp_v11->get_attrs();
174+
const auto& interpMode = tmpInterpAttr.mode;
189175
if (!one_of(interpMode, ngInterpMode::BILINEAR_PILLOW, ngInterpMode::BICUBIC_PILLOW)) {
190176
errorMessage = "Interpolate-11 does not support interpolate mode: " + ov::as_string(interpMode);
191177
return false;
192178
}
193-
const auto& interpShapeCalcMode = interpAttr.shape_calculation_mode;
179+
const auto& interpShapeCalcMode = tmpInterpAttr.shape_calculation_mode;
194180
if (!one_of(interpShapeCalcMode, ngInterpShapeCalcMode::SCALES, ngInterpShapeCalcMode::SIZES)) {
195181
errorMessage =
196182
"Interpolate-11 does not support shape_calculation_mode: " + ov::as_string(interpShapeCalcMode);
197183
return false;
198184
}
199-
const size_t dataRank = interp->get_input_partial_shape(DATA_ID).rank().get_length();
185+
const size_t dataRank = interp_v11->get_input_partial_shape(DATA_ID).rank().get_length();
200186
if (dataRank < 2 || dataRank > 4) {
201187
// pillow only resize on H and W. resize on D(depth) is not defined.
202188
errorMessage = "Interpolate-11 does not support input tensor of rank : " + std::to_string(dataRank);
@@ -207,8 +193,8 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
207193
errorMessage = "Only const 'scales_or_sizes' input is supported for static shapes in Interpolate-11";
208194
return false;
209195
}
210-
if (interp->get_input_size() > 2 && ov::as_type_ptr<const ov::op::v0::Constant>(
211-
interp->get_input_node_shared_ptr(AXES_ID_V11)) == nullptr) {
196+
if (interp_v11->get_input_size() > 2 && ov::as_type_ptr<const ov::op::v0::Constant>(
197+
interp_v11->get_input_node_shared_ptr(AXES_ID_V11)) == nullptr) {
212198
errorMessage = "Only const 'axes' input is supported in Interpolate-11";
213199
return false;
214200
}
@@ -257,8 +243,8 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
257243
: Node(op, context, InterpolateShapeInferFactory(op)) {
258244
std::string errorMessage;
259245
if (isSupportedOperation(op, errorMessage)) {
260-
dataRank = getInputShapeAtPort(DATA_ID).getRank();
261-
if (const auto interp = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
246+
dataRank = getInputShapeAtPort(interpAttrs.DATA_ID).getRank();
247+
if (const auto interp_v4 = ov::as_type_ptr<const ov::op::v4::Interpolate>(op)) {
262248
is_version11 = false;
263249
const auto numInputs = inputShapes.size();
264250
if (numInputs != 3 && numInputs != 4) {
@@ -269,7 +255,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
269255
}
270256
isAxesSpecified = numInputs != 3;
271257

272-
const auto& interpAttr = interp->get_attrs();
258+
const auto& interpAttr = interp_v4->get_attrs();
273259

274260
const auto& interpMode = interpAttr.mode;
275261
if (interpMode == ngInterpMode::NEAREST) {
@@ -351,14 +337,14 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
351337
}
352338

353339
const auto scalesNode =
354-
ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr(SCALES_ID));
340+
ov::as_type_ptr<const ov::op::v0::Constant>(interp_v4->get_input_node_shared_ptr(interpAttrs.SCALES_ID));
355341
if (scalesNode) {
356342
scales = scalesNode->cast_vector<float>();
357343
isScaleConstant = true;
358344
}
359345

360346
if (isAxesSpecified) {
361-
axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr(AXES_ID))
347+
axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp_v4->get_input_node_shared_ptr(interpAttrs.AXES_ID))
362348
->cast_vector<int>();
363349
} else {
364350
axes.resize(dataRank);
@@ -396,7 +382,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
396382
if (interpShapeCalcMode == ngInterpShapeCalcMode::SCALES) {
397383
interpAttrs.shapeCalcMode = InterpolateShapeCalcMode::scales;
398384
const auto scalesNode = ov::as_type_ptr<const ov::op::v0::Constant>(
399-
interp->get_input_node_shared_ptr(SIZE_OR_SCALE_ID_V11));
385+
interp->get_input_node_shared_ptr(interpAttrs.SIZE_OR_SCALE_ID_V11));
400386
if (scalesNode) {
401387
scales = scalesNode->cast_vector<float>();
402388
isScaleConstant = true;
@@ -426,7 +412,7 @@ Interpolate::Interpolate(const std::shared_ptr<ov::Node>& op, const GraphContext
426412
}
427413

428414
if (isAxesSpecified) {
429-
axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr(AXES_ID_V11))
415+
axes = ov::as_type_ptr<const ov::op::v0::Constant>(interp->get_input_node_shared_ptr(interpAttrs.AXES_ID_V11))
430416
->cast_vector<int>();
431417
if (dataRank == 4 && axes.size() == 2 && axes[0] == 1 && axes[1] == 2) {
432418
interpAttrs.NCHWAsNHWC = true;
@@ -496,7 +482,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
496482
return;
497483
}
498484

499-
ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort(DATA_ID);
485+
ov::element::Type inputPrecision = getOriginalInputPrecisionAtPort(interpAttrs.DATA_ID);
500486

501487
#if defined(OV_CPU_WITH_ACL)
502488
bool isInputPrecisionSupported = one_of(inputPrecision, ov::element::i8, ov::element::u8, ov::element::f16);
@@ -519,7 +505,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
519505
ov::element::Type outputPrecision = inputPrecision;
520506

521507
if (!fusedWith.empty()) {
522-
outputPrecision = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(DATA_ID);
508+
outputPrecision = fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(interpAttrs.DATA_ID);
523509
}
524510

525511
#if !defined(OV_CPU_WITH_ACL)
@@ -550,29 +536,29 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
550536
auto& creatorsMap = BlockedDescCreator::getCommonCreators();
551537
auto pushDesc = [&](LayoutType dataFormat,
552538
impl_desc_type implDetail,
553-
bool is_version11,
539+
bool is_version11_desc,
554540
bool useAclExecutor = false) {
555-
config.inConfs[DATA_ID].setMemDesc(
556-
creatorsMap.at(dataFormat)->createSharedDesc(inputPrecision, getInputShapeAtPort(DATA_ID)));
557-
if (is_version11) {
541+
config.inConfs[interpAttrs.DATA_ID].setMemDesc(
542+
creatorsMap.at(dataFormat)->createSharedDesc(inputPrecision, getInputShapeAtPort(interpAttrs.DATA_ID)));
543+
if (is_version11_desc) {
558544
if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes) {
559-
config.inConfs[SIZE_OR_SCALE_ID_V11].setMemDesc(
545+
config.inConfs[interpAttrs.SIZE_OR_SCALE_ID_V11].setMemDesc(
560546
creatorsMap.at(LayoutType::ncsp)
561-
->createSharedDesc(targetShapeType, getInputShapeAtPort(SIZE_OR_SCALE_ID_V11)));
547+
->createSharedDesc(targetShapeType, getInputShapeAtPort(interpAttrs.SIZE_OR_SCALE_ID_V11)));
562548
} else {
563-
config.inConfs[SIZE_OR_SCALE_ID_V11].setMemDesc(
549+
config.inConfs[interpAttrs.SIZE_OR_SCALE_ID_V11].setMemDesc(
564550
creatorsMap.at(LayoutType::ncsp)
565-
->createSharedDesc(scalesType, getInputShapeAtPort(SIZE_OR_SCALE_ID_V11)));
551+
->createSharedDesc(scalesType, getInputShapeAtPort(interpAttrs.SIZE_OR_SCALE_ID_V11)));
566552
}
567553

568554
if (isAxesSpecified) {
569-
config.inConfs[AXES_ID_V11].setMemDesc(
570-
creatorsMap.at(LayoutType::ncsp)->createSharedDesc(axesType, getInputShapeAtPort(AXES_ID_V11)));
555+
config.inConfs[interpAttrs.AXES_ID_V11].setMemDesc(
556+
creatorsMap.at(LayoutType::ncsp)->createSharedDesc(axesType, getInputShapeAtPort(interpAttrs.AXES_ID_V11)));
571557
}
572558
} else {
573-
config.inConfs[TARGET_SHAPE_ID].setMemDesc(
559+
config.inConfs[interpAttrs.TARGET_SHAPE_ID].setMemDesc(
574560
creatorsMap.at(LayoutType::ncsp)
575-
->createSharedDesc(targetShapeType, getInputShapeAtPort(TARGET_SHAPE_ID)));
561+
->createSharedDesc(targetShapeType, getInputShapeAtPort(interpAttrs.TARGET_SHAPE_ID)));
576562
config.inConfs[get_scale_id()].setMemDesc(
577563
creatorsMap.at(LayoutType::ncsp)->createSharedDesc(scalesType, getInputShapeAtPort(get_scale_id())));
578564

@@ -615,8 +601,9 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
615601
pushDesc(LayoutType::nspc, undef, true, true);
616602
pushDesc(LayoutType::ncsp, undef, true, true);
617603
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
618-
if (canUseAclExecutor)
604+
if (canUseAclExecutor) {
619605
return;
606+
}
620607
// fallback to f32 if ref is used
621608
inputPrecision = outputPrecision = ov::element::f32;
622609
#endif
@@ -644,16 +631,17 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
644631
}
645632
pushDesc(LayoutType::ncsp, ref, true);
646633
} else {
647-
const auto& dataMinDims = getInputShapeAtPort(DATA_ID).getMinDims();
634+
const auto& dataMinDims = getInputShapeAtPort(interpAttrs.DATA_ID).getMinDims();
648635
bool isBlkApplied = dataRank > 1 && dataMinDims[1] != Shape::UNDEFINED_DIM && dataMinDims[1] > 1;
649636

650637
#if defined(OV_CPU_WITH_ACL)
651638
interpAttrs.hasPad = hasPad;
652639
pushDesc(LayoutType::nspc, undef, false, true);
653640
pushDesc(LayoutType::ncsp, undef, false, true);
654641
canUseAclExecutor = !supportedPrimitiveDescriptors.empty();
655-
if (canUseAclExecutor)
642+
if (canUseAclExecutor) {
656643
return;
644+
}
657645
// fallback to f32 if ref is used
658646
inputPrecision = outputPrecision = ov::element::f32;
659647
#endif
@@ -703,17 +691,17 @@ bool Interpolate::needShapeInfer() const {
703691
if (lastScales.empty()) {
704692
return true;
705693
}
706-
const auto* scales = getSrcDataAtPortAs<const float>(get_scale_id());
694+
const auto* scales_inf = getSrcDataAtPortAs<const float>(get_scale_id());
707695
for (size_t i = 0; i < lastScales.size(); i++) {
708-
if (lastScales[i] != scales[i]) {
696+
if (lastScales[i] != scales_inf[i]) {
709697
return true;
710698
}
711699
}
712700
} else {
713701
if (lastSizes.empty()) {
714702
return true;
715703
}
716-
const auto* sizes = getSrcDataAtPortAs<const int32_t>(TARGET_SHAPE_ID);
704+
const auto* sizes = getSrcDataAtPortAs<const int32_t>(interpAttrs.TARGET_SHAPE_ID);
717705
for (size_t i = 0; i < lastSizes.size(); i++) {
718706
if (sizes[i] != lastSizes[i]) {
719707
return true;
@@ -726,11 +714,11 @@ bool Interpolate::needShapeInfer() const {
726714
void Interpolate::executeDynamicImpl(const dnnl::stream& strm) {
727715
execute(strm);
728716

729-
const size_t port = interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes ? TARGET_SHAPE_ID : get_scale_id();
717+
const size_t port = interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes ? interpAttrs.TARGET_SHAPE_ID : get_scale_id();
730718
const auto& memory = getParentEdgeAt(port)->getMemory();
731719
if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::scales) {
732-
const auto* scales = memory.getDataAs<const float>();
733-
lastScales.assign(scales, scales + memory.getDesc().getShape().getElementsCount());
720+
const auto* scales_dyn = memory.getDataAs<const float>();
721+
lastScales.assign(scales_dyn, scales_dyn + memory.getDesc().getShape().getElementsCount());
734722
} else {
735723
const auto* sizes = memory.getDataAs<const int32_t>();
736724
lastSizes.assign(sizes, sizes + memory.getDesc().getShape().getElementsCount());
@@ -743,15 +731,15 @@ bool Interpolate::needPrepareParams() const {
743731

744732
inline int Interpolate::get_scale_id() const {
745733
if (is_version11) {
746-
return SIZE_OR_SCALE_ID_V11;
734+
return interpAttrs.SIZE_OR_SCALE_ID_V11;
747735
}
748-
return SCALES_ID;
736+
return interpAttrs.SCALES_ID;
749737
}
750738
inline int Interpolate::get_axis_id() const {
751739
if (is_version11) {
752-
return AXES_ID_V11;
740+
return interpAttrs.AXES_ID_V11;
753741
}
754-
return AXES_ID;
742+
return interpAttrs.AXES_ID;
755743
}
756744

757745
void Interpolate::prepareParams() {
@@ -764,13 +752,13 @@ void Interpolate::prepareParams() {
764752
THROW_CPU_NODE_ERR("has undefined destination memory");
765753
}
766754

767-
auto srcMemPtr = getSrcMemoryAtPort(DATA_ID);
755+
auto srcMemPtr = getSrcMemoryAtPort(interpAttrs.DATA_ID);
768756
if (!srcMemPtr || !srcMemPtr->isDefined()) {
769757
THROW_CPU_NODE_ERR("has undefined input memory");
770758
}
771759

772760
if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes) {
773-
auto tsMemPtr = getSrcMemoryAtPort(TARGET_SHAPE_ID);
761+
auto tsMemPtr = getSrcMemoryAtPort(interpAttrs.TARGET_SHAPE_ID);
774762
if (!tsMemPtr || !tsMemPtr->isDefined()) {
775763
THROW_CPU_NODE_ERR("has undefined target shape memory");
776764
}
@@ -884,7 +872,7 @@ void Interpolate::prepareParams() {
884872
}
885873

886874
void Interpolate::createPrimitive() {
887-
auto srcMemPtr = getSrcMemoryAtPort(DATA_ID);
875+
auto srcMemPtr = getSrcMemoryAtPort(interpAttrs.DATA_ID);
888876
auto dstMemPtr = getDstMemoryAtPort(0);
889877
if (!srcMemPtr) {
890878
THROW_CPU_NODE_ERR("has null input memory");
@@ -978,7 +966,7 @@ std::vector<float> Interpolate::getScales(const VectorDims& srcDimPad, const Vec
978966

979967
void Interpolate::execute([[maybe_unused]] const dnnl::stream& strm) {
980968
auto dstMemPtr = getDstMemoryAtPort(0);
981-
auto srcMemPtr = getSrcMemoryAtPort(DATA_ID);
969+
auto srcMemPtr = getSrcMemoryAtPort(interpAttrs.DATA_ID);
982970

983971
if (execPtr) {
984972
auto* dst_data = dstMemPtr->getDataAs<uint8_t>();
@@ -1081,22 +1069,6 @@ void Interpolate::execute([[maybe_unused]] const dnnl::stream& strm) {
10811069
}
10821070
}
10831071

1084-
1085-
size_t Interpolate::getSpatialDimsNum(const Dim rank) {
1086-
switch (rank) {
1087-
case 1:
1088-
case 3:
1089-
return 1;
1090-
case 2:
1091-
case 4:
1092-
return 2;
1093-
case 5:
1094-
return 3;
1095-
default:
1096-
OPENVINO_THROW("Can't define number spatial");
1097-
}
1098-
}
1099-
11001072
bool Interpolate::canFuse(const NodePtr& node) const {
11011073
if (!mayiuse(cpu::x64::sse41) || interpAttrs.mode == InterpolateMode::linear ||
11021074
interpAttrs.mode == InterpolateMode::bilinear_pillow || interpAttrs.mode == InterpolateMode::bicubic_pillow ||

0 commit comments

Comments
 (0)