From f51e4f4f05e84d7cfebbc07e072fdd3c72753d7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20Sch=C3=BCtte?= Date: Tue, 24 Apr 2018 15:57:44 +0200 Subject: [PATCH] Introduce Task (#1066) This introduces the new class common::Task. A Task can have dependencies to other tasks, notify ThreadPoolInterface when all its dependencies are fulfilled, and can be executed in the background. --- .../testing/thread_pool_for_testing.h | 11 +- cartographer/common/task.cc | 107 ++++++++++ cartographer/common/task.h | 75 +++++++ cartographer/common/task_test.cc | 188 ++++++++++++++++++ cartographer/common/thread_pool.cc | 6 + cartographer/common/thread_pool.h | 30 +++ 6 files changed, 416 insertions(+), 1 deletion(-) create mode 100644 cartographer/common/task.cc create mode 100644 cartographer/common/task.h create mode 100644 cartographer/common/task_test.cc diff --git a/cartographer/common/internal/testing/thread_pool_for_testing.h b/cartographer/common/internal/testing/thread_pool_for_testing.h index 4dc39ff..8a8778c 100644 --- a/cartographer/common/internal/testing/thread_pool_for_testing.h +++ b/cartographer/common/internal/testing/thread_pool_for_testing.h @@ -34,7 +34,16 @@ class ThreadPoolForTesting : public ThreadPoolInterface { ThreadPoolForTesting(); ~ThreadPoolForTesting(); - void Schedule(const std::function &work_item) override; + void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override { + LOG(FATAL) << "not implemented"; + } + + void Schedule(const std::function& work_item) override; + std::weak_ptr Schedule(std::unique_ptr task) + EXCLUDES(mutex_) override { + LOG(FATAL) << "not implemented"; + } + void WaitUntilIdle(); private: diff --git a/cartographer/common/task.cc b/cartographer/common/task.cc new file mode 100644 index 0000000..7202268 --- /dev/null +++ b/cartographer/common/task.cc @@ -0,0 +1,107 @@ +/* + * Copyright 2018 The Cartographer Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cartographer/common/task.h" + +namespace cartographer { +namespace common { + +Task::~Task() { + // TODO(gaschler): Relax some checks after testing. + if (state_ != NEW && state_ != COMPLETED) { + LOG(WARNING) << "Delete Task between dispatch and completion."; + } +} + +Task::State Task::GetState() { + MutexLocker locker(&mutex_); + return state_; +} + +void Task::SetWorkItem(const WorkItem& work_item) { + MutexLocker locker(&mutex_); + CHECK_EQ(state_, NEW); + work_item_ = work_item; +} + +void Task::AddDependency(std::weak_ptr dependency) { + std::shared_ptr shared_dependency; + { + MutexLocker locker(&mutex_); + CHECK_EQ(state_, NEW); + if (shared_dependency = dependency.lock()) { + ++uncompleted_dependencies_; + } + } + if (shared_dependency) { + shared_dependency->AddDependentTask(this); + } +} + +void Task::SetThreadPool(ThreadPoolInterface* thread_pool) { + MutexLocker locker(&mutex_); + CHECK_EQ(state_, NEW); + state_ = DISPATCHED; + thread_pool_to_notify_ = thread_pool; + if (uncompleted_dependencies_ == 0) { + state_ = DEPENDENCIES_COMPLETED; + CHECK(thread_pool_to_notify_); + thread_pool_to_notify_->NotifyDependenciesCompleted(this); + } +} + +void Task::AddDependentTask(Task* dependent_task) { + MutexLocker locker(&mutex_); + if (state_ == COMPLETED) { + dependent_task->OnDependenyCompleted(); + return; + } + bool inserted = dependent_tasks_.insert(dependent_task).second; + CHECK(inserted) << "Given dependency is already a dependency."; +} + +void Task::OnDependenyCompleted() { + MutexLocker locker(&mutex_); + CHECK(state_ == NEW || state_ == DISPATCHED); + --uncompleted_dependencies_; + if (uncompleted_dependencies_ == 0 && state_ == DISPATCHED) { + state_ = DEPENDENCIES_COMPLETED; + CHECK(thread_pool_to_notify_); + thread_pool_to_notify_->NotifyDependenciesCompleted(this); + } +} + +void Task::Execute() { + { + MutexLocker locker(&mutex_); + CHECK_EQ(state_, DEPENDENCIES_COMPLETED); + state_ = RUNNING; + } + + // Execute the work item. + if (work_item_) { + work_item_(); + } + + MutexLocker locker(&mutex_); + state_ = COMPLETED; + for (Task* dependent_task : dependent_tasks_) { + dependent_task->OnDependenyCompleted(); + } +} + +} // namespace common +} // namespace cartographer diff --git a/cartographer/common/task.h b/cartographer/common/task.h new file mode 100644 index 0000000..360989d --- /dev/null +++ b/cartographer/common/task.h @@ -0,0 +1,75 @@ +/* + * Copyright 2018 The Cartographer Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CARTOGRAPHER_COMMON_TASK_H_ +#define CARTOGRAPHER_COMMON_TASK_H_ + +#include + +#include "cartographer/common/mutex.h" +#include "glog/logging.h" +#include "thread_pool.h" + +namespace cartographer { +namespace common { + +class ThreadPoolInterface; + +class Task { + public: + friend class ThreadPoolInterface; + + using WorkItem = std::function; + enum State { NEW, DISPATCHED, DEPENDENCIES_COMPLETED, RUNNING, COMPLETED }; + + ~Task(); + + State GetState() EXCLUDES(mutex_); + + // State must be 'NEW'. + void SetWorkItem(const WorkItem& work_item) EXCLUDES(mutex_); + + // State must be 'NEW'. 'dependency' may be nullptr, in which case it is + // assumed completed. + void AddDependency(std::weak_ptr dependency) EXCLUDES(mutex_); + + private: + // Allowed in all states. + void AddDependentTask(Task* dependent_task); + + // State must be 'DEPENDENCIES_COMPLETED' and becomes 'COMPLETED'. + void Execute() EXCLUDES(mutex_); + + // State must be 'NEW' and becomes 'DISPATCHED' or 'DEPENDENCIES_COMPLETED'. + void SetThreadPool(ThreadPoolInterface* thread_pool) EXCLUDES(mutex_); + + // State must be 'NEW' or 'DISPATCHED'. If 'DISPATCHED', may become + // 'DEPENDENCIES_COMPLETED'. + void OnDependenyCompleted(); + + WorkItem work_item_ GUARDED_BY(mutex_); + ThreadPoolInterface* thread_pool_to_notify_ GUARDED_BY(mutex_) = nullptr; + State state_ GUARDED_BY(mutex_) = NEW; + unsigned int uncompleted_dependencies_ GUARDED_BY(mutex_) = 0; + std::set dependent_tasks_ GUARDED_BY(mutex_); + + Mutex mutex_; +}; + +} // namespace common +} // namespace cartographer + +#endif // CARTOGRAPHER_COMMON_TASK_H_ diff --git a/cartographer/common/task_test.cc b/cartographer/common/task_test.cc new file mode 100644 index 0000000..935b7a5 --- /dev/null +++ b/cartographer/common/task_test.cc @@ -0,0 +1,188 @@ +/* + * Copyright 2018 The Cartographer Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cartographer/common/task.h" + +#include +#include + +#include "cartographer/common/make_unique.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace cartographer { +namespace common { +namespace { + +class MockCallback { + public: + MOCK_METHOD0(Run, void()); +}; + +class FakeThreadPool : public ThreadPoolInterface { + public: + void NotifyDependenciesCompleted(Task* task) { + auto it = tasks_not_ready_.find(task); + ASSERT_NE(it, tasks_not_ready_.end()); + task_queue_.push_back(it->second); + tasks_not_ready_.erase(it); + } + + void Schedule(const std::function& work_item) { + LOG(FATAL) << "not implemented"; + } + + std::weak_ptr Schedule(std::unique_ptr task) override { + auto it = + tasks_not_ready_.insert(std::make_pair(task.get(), std::move(task))); + EXPECT_TRUE(it.second); + SetThreadPool(it.first->first); + return it.first->second; + } + + void RunNext() { + ASSERT_GE(task_queue_.size(), 1); + Execute(task_queue_.front().get()); + task_queue_.pop_front(); + } + + bool IsEmpty() { return task_queue_.empty(); } + + private: + std::deque> task_queue_; + std::map> tasks_not_ready_; +}; + +class TaskTest : public ::testing::Test { + protected: + FakeThreadPool* thread_pool() { return &thread_pool_; } + FakeThreadPool thread_pool_; +}; + +TEST_F(TaskTest, RunTask) { + auto a = make_unique(); + MockCallback callback; + a->SetWorkItem([&callback]() { callback.Run(); }); + EXPECT_EQ(a->GetState(), Task::NEW); + auto shared_a = thread_pool()->Schedule(std::move(a)).lock(); + EXPECT_NE(shared_a, nullptr); + EXPECT_EQ(shared_a->GetState(), Task::DEPENDENCIES_COMPLETED); + EXPECT_CALL(callback, Run()).Times(1); + thread_pool()->RunNext(); + EXPECT_EQ(shared_a->GetState(), Task::COMPLETED); + EXPECT_TRUE(thread_pool()->IsEmpty()); +} + +TEST_F(TaskTest, RunTaskWithDependency) { + auto a = make_unique(); + auto b = make_unique(); + MockCallback callback_a; + a->SetWorkItem([&callback_a]() { callback_a.Run(); }); + MockCallback callback_b; + b->SetWorkItem([&callback_b]() { callback_b.Run(); }); + EXPECT_EQ(a->GetState(), Task::NEW); + EXPECT_EQ(b->GetState(), Task::NEW); + { + ::testing::InSequence dummy; + EXPECT_CALL(callback_a, Run()).Times(1); + EXPECT_CALL(callback_b, Run()).Times(1); + } + auto shared_a = thread_pool()->Schedule(std::move(a)).lock(); + EXPECT_NE(shared_a, nullptr); + b->AddDependency(shared_a); + auto shared_b = thread_pool()->Schedule(std::move(b)).lock(); + EXPECT_NE(shared_b, nullptr); + EXPECT_EQ(shared_b->GetState(), Task::DISPATCHED); + EXPECT_EQ(shared_a->GetState(), Task::DEPENDENCIES_COMPLETED); + thread_pool()->RunNext(); + EXPECT_EQ(shared_b->GetState(), Task::DEPENDENCIES_COMPLETED); + thread_pool()->RunNext(); + EXPECT_EQ(shared_a->GetState(), Task::COMPLETED); + EXPECT_EQ(shared_b->GetState(), Task::COMPLETED); +} + +TEST_F(TaskTest, RunTaskWithTwoDependency) { + /* c \ + * a --> b --> d + */ + auto a = make_unique(); + auto b = make_unique(); + auto c = make_unique(); + auto d = make_unique(); + MockCallback callback_a; + a->SetWorkItem([&callback_a]() { callback_a.Run(); }); + MockCallback callback_b; + b->SetWorkItem([&callback_b]() { callback_b.Run(); }); + MockCallback callback_c; + c->SetWorkItem([&callback_c]() { callback_c.Run(); }); + MockCallback callback_d; + d->SetWorkItem([&callback_d]() { callback_d.Run(); }); + EXPECT_CALL(callback_a, Run()).Times(1); + EXPECT_CALL(callback_b, Run()).Times(1); + EXPECT_CALL(callback_c, Run()).Times(1); + EXPECT_CALL(callback_d, Run()).Times(1); + auto shared_a = thread_pool()->Schedule(std::move(a)).lock(); + EXPECT_NE(shared_a, nullptr); + b->AddDependency(shared_a); + auto shared_b = thread_pool()->Schedule(std::move(b)).lock(); + EXPECT_NE(shared_b, nullptr); + auto shared_c = thread_pool()->Schedule(std::move(c)).lock(); + EXPECT_NE(shared_c, nullptr); + d->AddDependency(shared_b); + d->AddDependency(shared_c); + auto shared_d = thread_pool()->Schedule(std::move(d)).lock(); + EXPECT_NE(shared_d, nullptr); + EXPECT_EQ(shared_b->GetState(), Task::DISPATCHED); + EXPECT_EQ(shared_d->GetState(), Task::DISPATCHED); + thread_pool()->RunNext(); + EXPECT_EQ(shared_a->GetState(), Task::COMPLETED); + EXPECT_EQ(shared_b->GetState(), Task::DEPENDENCIES_COMPLETED); + EXPECT_EQ(shared_c->GetState(), Task::DEPENDENCIES_COMPLETED); + thread_pool()->RunNext(); + thread_pool()->RunNext(); + EXPECT_EQ(shared_b->GetState(), Task::COMPLETED); + EXPECT_EQ(shared_c->GetState(), Task::COMPLETED); + EXPECT_EQ(shared_d->GetState(), Task::DEPENDENCIES_COMPLETED); + thread_pool()->RunNext(); + EXPECT_EQ(shared_d->GetState(), Task::COMPLETED); +} + +TEST_F(TaskTest, RunWithCompletedDependency) { + auto a = make_unique(); + MockCallback callback_a; + a->SetWorkItem([&callback_a]() { callback_a.Run(); }); + auto shared_a = thread_pool()->Schedule(std::move(a)).lock(); + EXPECT_NE(shared_a, nullptr); + EXPECT_EQ(shared_a->GetState(), Task::DEPENDENCIES_COMPLETED); + EXPECT_CALL(callback_a, Run()).Times(1); + thread_pool()->RunNext(); + EXPECT_EQ(shared_a->GetState(), Task::COMPLETED); + auto b = make_unique(); + MockCallback callback_b; + b->SetWorkItem([&callback_b]() { callback_b.Run(); }); + b->AddDependency(shared_a); + EXPECT_EQ(b->GetState(), Task::NEW); + auto shared_b = thread_pool()->Schedule(std::move(b)).lock(); + EXPECT_NE(shared_b, nullptr); + EXPECT_EQ(shared_b->GetState(), Task::DEPENDENCIES_COMPLETED); + EXPECT_CALL(callback_b, Run()).Times(1); + thread_pool()->RunNext(); + EXPECT_EQ(shared_b->GetState(), Task::COMPLETED); +} + +} // namespace +} // namespace common +} // namespace cartographer diff --git a/cartographer/common/thread_pool.cc b/cartographer/common/thread_pool.cc index fdda166..d36c83a 100644 --- a/cartographer/common/thread_pool.cc +++ b/cartographer/common/thread_pool.cc @@ -26,6 +26,12 @@ namespace cartographer { namespace common { +void ThreadPoolInterface::Execute(Task* task) { task->Execute(); } + +void ThreadPoolInterface::SetThreadPool(Task* task) { + task->SetThreadPool(this); +} + ThreadPool::ThreadPool(int num_threads) { MutexLocker locker(&mutex_); for (int i = 0; i != num_threads; ++i) { diff --git a/cartographer/common/thread_pool.h b/cartographer/common/thread_pool.h index 7ed3508..2121629 100644 --- a/cartographer/common/thread_pool.h +++ b/cartographer/common/thread_pool.h @@ -19,17 +19,35 @@ #include #include +#include +#include #include #include #include "cartographer/common/mutex.h" +#include "cartographer/common/task.h" namespace cartographer { namespace common { +class Task; + class ThreadPoolInterface { public: + ThreadPoolInterface() {} + virtual ~ThreadPoolInterface() {} + // TODO(gaschler): Use Schedule(unique_ptr), then remove Schedule. virtual void Schedule(const std::function& work_item) = 0; + virtual std::weak_ptr Schedule(std::unique_ptr task) = 0; + + protected: + void Execute(Task* task); + void SetThreadPool(Task* task); + + private: + friend class Task; + + virtual void NotifyDependenciesCompleted(Task* task) = 0; }; // A fixed number of threads working on a work queue of work items. Adding a @@ -45,11 +63,23 @@ class ThreadPool : public ThreadPoolInterface { ThreadPool(const ThreadPool&) = delete; ThreadPool& operator=(const ThreadPool&) = delete; + // TODO(gaschler): Remove all uses. void Schedule(const std::function& work_item) override; + // When the returned weak pointer is expired, 'task' has certainly completed, + // so dependants no longer need to add it as a dependency. + std::weak_ptr Schedule(std::unique_ptr task) + EXCLUDES(mutex_) override { + LOG(FATAL) << "not implemented"; + } + private: void DoWork(); + void NotifyDependenciesCompleted(Task* task) EXCLUDES(mutex_) override { + LOG(FATAL) << "not implemented"; + } + Mutex mutex_; bool running_ GUARDED_BY(mutex_) = true; std::vector pool_ GUARDED_BY(mutex_);