
Commit cbda3e1

Refactor CompiledFunction to remove per-run state (V2)

Parent: da381d5

13 files changed: +233 -224 lines

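In short: the five-call per-run lifecycle (setupRuns / beforeRun / execute / afterRun / tearDownRuns) collapses into a single execute(Context *) call, and the per-run buffers move from member state into execute() itself. A minimal caller-side sketch of the change, mirroring the CPUDeviceManager diff below (func and ctx stand in for whatever the caller holds):

  // Before: the caller drove each phase, and the CompiledFunction kept
  // per-run pointers as member state between the calls.
  func->setupRuns();
  func->beforeRun(*ctx);
  func->execute();
  func->afterRun(*ctx);
  func->tearDownRuns();

  // After: one call; execute() allocates, binds, runs, copies results
  // back, and frees its own per-run buffers.
  func->execute(ctx.get());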

include/glow/Backends/CompiledFunction.h

Lines changed: 9 additions & 9 deletions
@@ -38,21 +38,21 @@ class CompiledFunction {
   virtual ~CompiledFunction() = default;
   /// Execute the network and allocate Placeholder memory with given
   /// \p ctx providing mapping between Placeholder and populated tensor.
-  virtual void execute() = 0;
+  virtual void execute(Context *ctx) = 0;

   /// Does any needed initialization work for the Backend.
   /// This includes device init constant memory allocation and copying to
-  /// device.
-  virtual void setupRuns() = 0;
+  /// device. \deprecated
+  virtual void setupRuns() { runsSetup_ = true; }

-  /// Per run setup. Copy inputs to device.
-  virtual void beforeRun(const Context &ctx) = 0;
+  /// Per run setup. Copy inputs to device. \deprecated
+  virtual void beforeRun(const Context &ctx) {}

-  /// Per run cleanup. Copy outputs from device.
-  virtual void afterRun(const Context &ctx) = 0;
+  /// Per run cleanup. Copy outputs from device. \deprecated
+  virtual void afterRun(const Context &ctx) {}

-  /// Final cleanup. Release memory, reset device.
-  virtual void tearDownRuns() = 0;
+  /// Final cleanup. Release memory, reset device. \deprecated
+  virtual void tearDownRuns() { runsSetup_ = false; }

   /// Getter for the runtimeBundle.
   const runtime::RuntimeBundle &getRuntimeBundle() const {
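
Note that the old hooks now keep default bodies rather than staying pure virtual, so existing backends continue to compile while they migrate. A minimal sketch of what a backend subclass now has to provide, assuming the usual RuntimeBundle-taking base constructor (MyBackendFunction is a hypothetical name, not part of this commit):

  class MyBackendFunction final : public CompiledFunction {
  public:
    explicit MyBackendFunction(const runtime::RuntimeBundle &bundle)
        : CompiledFunction(bundle) {}

    /// The only override a backend now needs: allocate per-run buffers,
    /// copy placeholder tensors in from \p ctx, run, copy results back,
    /// and free the buffers before returning.
    void execute(Context *ctx) override { /* backend-specific */ }

    // setupRuns/beforeRun/afterRun/tearDownRuns inherit the deprecated
    // no-op defaults shown above.
  };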

lib/Backends/CPU/CPUDeviceManager.cpp

Lines changed: 1 addition & 5 deletions
@@ -92,11 +92,7 @@ void CPUDeviceManager::runFunctionImpl(RunIdentifierTy id, std::string function,
   CompiledFunction *func = funcIt->second;

   // Run that function.
-  func->setupRuns();
-  func->beforeRun(*ctx);
-  func->execute();
-  func->afterRun(*ctx);
-  func->tearDownRuns();
+  func->execute(ctx.get());

   // Fire the resultCB.
   resultCB(id, ResultCode::Executed, std::move(ctx));

lib/Backends/CPU/CPUFunction.cpp

Lines changed: 30 additions & 33 deletions
@@ -30,63 +30,55 @@ CPUFunction::~CPUFunction() {
   tearDownRuns();
 }

-void CPUFunction::setupRuns() {
-  if (!runsSetup_) {
-    if (runtimeBundle_.getActivationsSize() != 0) {
-      baseActivationsAddress_ = (uint8_t *)alignedAlloc(
-          runtimeBundle_.getActivationsSize(), TensorAlignment);
-    }
-
-    if (runtimeBundle_.getMutableWeightSize() != 0) {
-      baseMutableWeightVarsAddress_ = (uint8_t *)alignedAlloc(
-          runtimeBundle_.getMutableWeightSize(), TensorAlignment);
-    }
-    runsSetup_ = true;
-  }
-}
-
 void CPUFunction::collectConstants(IRFunction *F) {
   runtimeBundle_.collectConstants(F);
 }

-void CPUFunction::beforeRun(const Context &ctx) {
+void CPUFunction::loadPlaceholders(Context *ctx,
+                                   uint8_t *baseMutableWeightVarsAddress) {
   // Copy Placeholders into allocated memory.
-  for (auto PH : ctx.pairs()) {
+  for (auto PH : ctx->pairs()) {
     auto payload = PH.second->getUnsafePtr();
     auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first);
     auto addr = symbolInfo.offset;
     auto numBytes = symbolInfo.size;
     // copy PH to allocated memory.
-    memcpy(baseMutableWeightVarsAddress_ + addr, payload, numBytes);
+    memcpy(baseMutableWeightVarsAddress + addr, payload, numBytes);
   }
 }

-void CPUFunction::afterRun(const Context &ctx) {
+void CPUFunction::updatePlaceholders(Context *ctx,
+                                     uint8_t *baseMutableWeightVarsAddress) {
   // Copy placeholders from device back into context.
-  for (auto PH : ctx.pairs()) {
+  for (auto PH : ctx->pairs()) {
     auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first);
-    auto payload = baseMutableWeightVarsAddress_ + symbolInfo.offset;
+    auto payload = baseMutableWeightVarsAddress + symbolInfo.offset;
     auto numBytes = symbolInfo.size;
     auto addr = PH.second->getUnsafePtr();
     // copy PH from allocated memory.
     memcpy(addr, payload, numBytes);
   }
 }

-void CPUFunction::tearDownRuns() {
-  if (baseMutableWeightVarsAddress_) {
-    alignedFree(baseMutableWeightVarsAddress_);
-    baseMutableWeightVarsAddress_ = nullptr;
+void CPUFunction::execute(Context *ctx) {
+  /// Base address for Activations memory block.
+  uint8_t *baseActivationsAddress{nullptr};
+
+  /// Base address for Mutable weights memory block, Inputs and Outputs.
+  uint8_t *baseMutableWeightVarsAddress{nullptr};
+
+  if (runtimeBundle_.getActivationsSize() != 0) {
+    baseActivationsAddress = (uint8_t *)alignedAlloc(
+        runtimeBundle_.getActivationsSize(), TensorAlignment);
   }

-  if (baseActivationsAddress_) {
-    alignedFree(baseActivationsAddress_);
-    baseActivationsAddress_ = nullptr;
+  if (runtimeBundle_.getMutableWeightSize() != 0) {
+    baseMutableWeightVarsAddress = (uint8_t *)alignedAlloc(
+        runtimeBundle_.getMutableWeightSize(), TensorAlignment);
   }
-  runsSetup_ = false;
-}

-void CPUFunction::execute() {
+  loadPlaceholders(ctx, baseMutableWeightVarsAddress);
+
   auto sym = JIT_->findSymbol("jitmain");
   assert(sym && "Unable to JIT the code!");
   using JitFuncType =
@@ -95,9 +87,14 @@ void CPUFunction::execute() {
   auto address = sym.getAddress();
   if (address) {
     JitFuncType funcPtr = reinterpret_cast<JitFuncType>(address.get());
-    funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress_,
-            baseActivationsAddress_);
+    funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress,
+            baseActivationsAddress);
   } else {
     GLOW_ASSERT(false && "Error getting address.");
   }
+
+  updatePlaceholders(ctx, baseMutableWeightVarsAddress);
+
+  alignedFree(baseMutableWeightVarsAddress);
+  alignedFree(baseActivationsAddress);
 }

lib/Backends/CPU/CPUFunction.h

Lines changed: 9 additions & 20 deletions
@@ -28,12 +28,6 @@ class CPUFunction final : public CompiledFunction {
   /// initializes the LLVM backends.
   std::unique_ptr<llvm::orc::GlowJIT> JIT_;

-  /// Base address for Activations memory block.
-  uint8_t *baseActivationsAddress_{};
-
-  /// Base address for Mutable weights memory block, Inputs and Outputs.
-  uint8_t *baseMutableWeightVarsAddress_{};
-
 public:
   /// Ctor.
   CPUFunction(std::unique_ptr<llvm::orc::GlowJIT> JIT,
@@ -42,24 +36,19 @@ class CPUFunction final : public CompiledFunction {
   /// Collects constants for runtime.
   void collectConstants(IRFunction *F);

-  /// Allocate Mutable buffers on device this includes Activations and
-  /// Placeholders.
-  void setupRuns() override;
-
-  /// Copy Input Placeholder data to position.
-  void beforeRun(const Context &ctx) override;
-
-  /// Copy Outputs to Placeholders in \p ctx.
-  void afterRun(const Context &ctx) override;
-
-  /// Final cleanup, free all allocations.
-  void tearDownRuns() override;
-
   /// \name CompiledFunction interface
   ///@{
   ~CPUFunction() override;
-  void execute() override;
+  void execute(Context *ctx) override;
   ///@}
+private:
+  /// Load constant tensors from \p ctx into \p weightsAddress, as defined by
+  /// the RuntimeBundle (pre-run).
+  void loadPlaceholders(Context *ctx, uint8_t *weightsAddress);
+
+  /// Load weights from \p weightsAddress into applicable backing tensors in
+  /// \p ctx, as defined by the RuntimeBundle (post-run).
+  void updatePlaceholders(Context *ctx, uint8_t *weightsAddress);
 };
 } // end namespace glow

lib/Backends/Interpreter/InterpreterFunction.cpp

Lines changed: 35 additions & 35 deletions
@@ -22,29 +22,26 @@

 #include "llvm/Support/Casting.h"

+#include "llvm/Support/raw_ostream.h"
 using namespace glow;

 InterpreterFunction::InterpreterFunction(std::unique_ptr<IRFunction> F,
                                          const runtime::RuntimeBundle &bundle)
     : CompiledFunction(bundle), F_(std::move(F)) {}

 InterpreterFunction::~InterpreterFunction() {
-  // Delete the tensors that are owned by this backend.
-  for (const auto &p : tensors_) {
+  for (const auto &p : constants_) {
     delete p.second;
   }
-  tensors_.clear();
-  externalTensors_.clear();
+  constants_.clear();
+
   alignedFree(runtimeBundle_.getConstants());
   tearDownRuns();
 }

 void InterpreterFunction::collectConstants(IRFunction *F) {
   runtimeBundle_.collectConstants(F);
-}
-
-void InterpreterFunction::setupRuns() {
-  if (!runsSetup_) {
+  if (constants_.empty()) {
     if (runtimeBundle_.getConstantWeightSize()) {
       for (const auto &v : F_->getGraph()->getParent()->getConstants()) {
         auto symbolInfo = runtimeBundle_.getSymbolInfo(v);
@@ -53,36 +50,27 @@ void InterpreterFunction::setupRuns() {
         constants_.emplace(std::string(v->getName()), tensor);
       }
     }
-    runsSetup_ = true;
-  }
-}
-
-void InterpreterFunction::beforeRun(const Context &ctx) {
-  // Register the concrete tensors that back the placeholder tensors.
-  for (auto &ph : ctx.pairs()) {
-    auto *w = F_->getWeightForNode(ph.first);
-    assert(!externalTensors_.count(w) && "The tensor is already registered");
-    externalTensors_[w] = ph.second;
   }
 }

-void InterpreterFunction::afterRun(const Context &ctx) {
-  // Remove the concrete tensors that back the placeholder tensors.
-  for (auto &ph : ctx.pairs()) {
-    auto *w = F_->getWeightForNode(ph.first);
-    externalTensors_.erase(w);
+void InterpreterFunction::execute(Context *ctx) {
+  if (constants_.empty()) {
+    collectConstants(F_.get());
   }
+  BoundInterpreterFunction boundFunc(constants_);
+  boundFunc.execute(F_.get(), ctx);
 }

-void InterpreterFunction::tearDownRuns() {
-  for (const auto &p : constants_) {
+BoundInterpreterFunction::~BoundInterpreterFunction() {
+  // Delete the tensors that are owned by this backend.
+  for (const auto &p : tensors_) {
     delete p.second;
   }
-  constants_.clear();
-  runsSetup_ = false;
+  tensors_.clear();
+  externalTensors_.clear();
 }

-Tensor *InterpreterFunction::getTensor(const Value *v) const {
+Tensor *BoundInterpreterFunction::getTensor(const Value *v) const {
   auto it = tensors_.find(v);
   if (it != tensors_.end()) {
     return it->second;
@@ -97,7 +85,7 @@ Tensor *InterpreterFunction::getTensor(const Value *v) const {
   return ie->second;
 }

-Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) {
+Tensor *BoundInterpreterFunction::getOrCreateTensor(const Value *v) {
   auto ie = externalTensors_.find(v);
   if (ie != externalTensors_.end()) {
     return ie->second;
@@ -117,9 +105,8 @@ Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) {
   return it->second;
 }

-Tensor *
-InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src,
-                                              llvm::ArrayRef<size_t> offsets) {
+Tensor *BoundInterpreterFunction::getOrCreateUnownedTensor(
+    const Value *v, const Value *src, llvm::ArrayRef<size_t> offsets) {
   assert(llvm::isa<TensorViewInst>(v) && "Expected a tensor view");

   // Pick the tensor.
@@ -136,7 +123,7 @@ InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src,
   return T;
 }

-void InterpreterFunction::deleteTensor(const Value *v) {
+void BoundInterpreterFunction::deleteTensor(const Value *v) {
   auto it = tensors_.find(v);
   if (it == tensors_.end()) {
     return;
@@ -146,7 +133,14 @@ void InterpreterFunction::deleteTensor(const Value *v) {
   tensors_.erase(it);
 }

-void InterpreterFunction::execute() {
+void BoundInterpreterFunction::execute(IRFunction *F, Context *ctx) {
+  // Register the concrete tensors that back the placeholder tensors.
+  for (auto &ph : ctx->pairs()) {
+    auto *w = F->getWeightForNode(ph.first);
+    assert(!externalTensors_.count(w) && "The tensor is already registered");
+    externalTensors_[w] = ph.second;
+  }
+
   // Do the forward pass.
 #define DEF_VALUE(CLASS, NAME)
 #define DEF_INSTR(CLASS, NAME) \
@@ -156,12 +150,18 @@ void InterpreterFunction::execute() {
   }
 #define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME)
   // Dispatch the interpreter on each instruction in the program:
-  for (const auto &I : F_->getInstrs()) {
+  for (const auto &I : F->getInstrs()) {
     switch (I.getKind()) {
 #include "glow/AutoGenInstr.def"

     default:
       llvm_unreachable("Invalid instruction.");
     }
   }
+
+  // Remove the concrete tensors that back the placeholder tensors.
+  for (auto &ph : ctx->pairs()) {
+    auto *w = F->getWeightForNode(ph.first);
+    externalTensors_.erase(w);
+  }
 }
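
InterpreterFunction.h is among the 13 changed files but is not shown in this excerpt; from the .cpp changes above, the new per-run object presumably looks roughly like the sketch below. The member names tensors_, externalTensors_, and constants_ come from the diff; the exact map types and constructor signature are guesses, and the per-instruction fwd* handlers invoked by the dispatch macro are omitted:

  class BoundInterpreterFunction {
  public:
    explicit BoundInterpreterFunction(
        const std::unordered_map<std::string, Tensor *> &constants)
        : constants_(constants) {}
    ~BoundInterpreterFunction();

    /// Binds the placeholder tensors from \p ctx, interprets \p F, then
    /// unbinds them again.
    void execute(IRFunction *F, Context *ctx);

  private:
    Tensor *getTensor(const Value *v) const;
    Tensor *getOrCreateTensor(const Value *v);
    Tensor *getOrCreateUnownedTensor(const Value *v, const Value *src,
                                     llvm::ArrayRef<size_t> offsets);
    void deleteTensor(const Value *v);

    // Constants shared with (and still owned by) the InterpreterFunction.
    const std::unordered_map<std::string, Tensor *> &constants_;
    // Tensors owned by this run, plus the externally owned tensors that
    // back placeholders for the duration of the run.
    std::unordered_map<const Value *, Tensor *> tensors_;
    std::unordered_map<const Value *, Tensor *> externalTensors_;
  };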
