diff --git a/cartographer_grpc/framework/execution_context.h b/cartographer_grpc/framework/execution_context.h new file mode 100644 index 0000000..060c30c --- /dev/null +++ b/cartographer_grpc/framework/execution_context.h @@ -0,0 +1,61 @@ +/* + * Copyright 2017 The Cartographer Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_EXECUTION_CONTEXT_H +#define CARTOGRAPHER_GRPC_FRAMEWORK_EXECUTION_CONTEXT_H + +#include "cartographer/common/mutex.h" +#include "glog/logging.h" + +namespace cartographer_grpc { +namespace framework { + +// Implementations of this class allow RPC handlers to share state among one +// another. Using Server::SetExecutionContext(...) a server-wide +// 'ExecutionContext' can be specified. This 'ExecutionContext' can be retrieved +// by all implementations of 'RpcHandler' by calling +// 'RpcHandler::GetContext()'. +class ExecutionContext { + public: + // This non-movable, non-copyable class is used to broker access from various + // RPC handlers to the shared 'ExecutionContext'. Handles automatically lock + // the context they point to. + template + class Synchronized { + public: + ContextType* operator->() { + return static_cast(execution_context_); + } + Synchronized(cartographer::common::Mutex* lock, + ExecutionContext* execution_context) + : locker_(lock), execution_context_(execution_context) {} + Synchronized(const Synchronized&) = delete; + Synchronized(Synchronized&&) = delete; + + private: + cartographer::common::MutexLocker locker_; + ExecutionContext* execution_context_; + }; + cartographer::common::Mutex* lock() { return &lock_; } + + private: + cartographer::common::Mutex lock_; +}; + +} // namespace framework +} // namespace cartographer_grpc + +#endif // CARTOGRAPHER_GRPC_FRAMEWORK_EXECUTION_CONTEXT_H diff --git a/cartographer_grpc/framework/rpc.cc b/cartographer_grpc/framework/rpc.cc index f616e8a..b8ab2ff 100644 --- a/cartographer_grpc/framework/rpc.cc +++ b/cartographer_grpc/framework/rpc.cc @@ -25,16 +25,18 @@ namespace framework { Rpc::Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, + ExecutionContext* execution_context, const RpcHandlerInfo& rpc_handler_info, Service* service) : method_index_(method_index), server_completion_queue_(server_completion_queue), + execution_context_(execution_context), rpc_handler_info_(rpc_handler_info), service_(service), new_connection_event_{Event::NEW_CONNECTION, this, false}, read_event_{Event::READ, this, false}, write_event_{Event::WRITE, this, false}, done_event_{Event::DONE, this, false}, - handler_(rpc_handler_info_.rpc_handler_factory(this)) { + handler_(rpc_handler_info_.rpc_handler_factory(this, execution_context)) { InitializeReadersAndWriters(rpc_handler_info_.rpc_type); // Initialize the prototypical request and response messages. @@ -48,7 +50,8 @@ Rpc::Rpc(int method_index, std::unique_ptr Rpc::Clone() { return cartographer::common::make_unique( - method_index_, server_completion_queue_, rpc_handler_info_, service_); + method_index_, server_completion_queue_, execution_context_, + rpc_handler_info_, service_); } void Rpc::OnRequest() { handler_->OnRequestInternal(request_.get()); } diff --git a/cartographer_grpc/framework/rpc.h b/cartographer_grpc/framework/rpc.h index be89118..a8a520f 100644 --- a/cartographer_grpc/framework/rpc.h +++ b/cartographer_grpc/framework/rpc.h @@ -21,6 +21,7 @@ #include #include "cartographer/common/mutex.h" +#include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/rpc_handler_interface.h" #include "google/protobuf/message.h" #include "grpc++/grpc++.h" @@ -47,6 +48,7 @@ class Rpc { }; Rpc(int method_index, ::grpc::ServerCompletionQueue* server_completion_queue, + ExecutionContext* execution_context, const RpcHandlerInfo& rpc_handler_info, Service* service); std::unique_ptr Clone(); void OnRequest(); @@ -69,6 +71,7 @@ class Rpc { int method_index_; ::grpc::ServerCompletionQueue* server_completion_queue_; + ExecutionContext* execution_context_; RpcHandlerInfo rpc_handler_info_; Service* service_; ::grpc::ServerContext server_context_; diff --git a/cartographer_grpc/framework/rpc_handler.h b/cartographer_grpc/framework/rpc_handler.h index 7680fc8..aab415c 100644 --- a/cartographer_grpc/framework/rpc_handler.h +++ b/cartographer_grpc/framework/rpc_handler.h @@ -17,6 +17,7 @@ #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_H +#include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/rpc.h" #include "cartographer_grpc/framework/rpc_handler_interface.h" #include "cartographer_grpc/framework/type_traits.h" @@ -35,6 +36,9 @@ class RpcHandler : public RpcHandlerInterface { using RequestType = StripStream; using ResponseType = StripStream; + void SetExecutionContext(ExecutionContext* execution_context) { + execution_context_ = execution_context; + } void SetRpc(Rpc* rpc) override { rpc_ = rpc; } void OnRequestInternal(const ::google::protobuf::Message* request) override { DCHECK(dynamic_cast(request)); @@ -44,9 +48,14 @@ class RpcHandler : public RpcHandlerInterface { void Send(std::unique_ptr response) { rpc_->Write(std::move(response)); } + template + ExecutionContext::Synchronized GetContext() { + return {execution_context_->lock(), execution_context_}; + } private: Rpc* rpc_; + ExecutionContext* execution_context_; }; } // namespace framework diff --git a/cartographer_grpc/framework/rpc_handler_interface.h b/cartographer_grpc/framework/rpc_handler_interface.h index fb36971..ab0c21b 100644 --- a/cartographer_grpc/framework/rpc_handler_interface.h +++ b/cartographer_grpc/framework/rpc_handler_interface.h @@ -17,6 +17,7 @@ #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H +#include "cartographer_grpc/framework/execution_context.h" #include "google/protobuf/message.h" #include "grpc++/grpc++.h" @@ -27,14 +28,15 @@ class Rpc; class RpcHandlerInterface { public: virtual ~RpcHandlerInterface() = default; + virtual void SetExecutionContext(ExecutionContext* execution_context) = 0; virtual void SetRpc(Rpc* rpc) = 0; virtual void OnRequestInternal( const ::google::protobuf::Message* request) = 0; virtual void OnReadsDone() = 0; }; -using RpcHandlerFactory = - std::function(Rpc*)>; +using RpcHandlerFactory = std::function( + Rpc*, ExecutionContext*)>; struct RpcHandlerInfo { const google::protobuf::Descriptor* request_descriptor; diff --git a/cartographer_grpc/framework/server.cc b/cartographer_grpc/framework/server.cc index 507af64..ef0669e 100644 --- a/cartographer_grpc/framework/server.cc +++ b/cartographer_grpc/framework/server.cc @@ -77,7 +77,8 @@ void Server::Start() { // Start serving all services on all completion queues. for (auto& service : services_) { - service.second.StartServing(completion_queue_threads_); + service.second.StartServing(completion_queue_threads_, + execution_context_.get()); } // Start threads to process all completion queues. @@ -107,5 +108,13 @@ void Server::Shutdown() { LOG(INFO) << "Shutdown complete."; } +void Server::SetExecutionContext( + std::unique_ptr execution_context) { + // After the server has been started the 'ExecutionHandle' cannot be changed + // anymore. + CHECK(!server_); + execution_context_ = std::move(execution_context); +} + } // namespace framework } // namespace cartographer_grpc diff --git a/cartographer_grpc/framework/server.h b/cartographer_grpc/framework/server.h index 2f68f27..a339965 100644 --- a/cartographer_grpc/framework/server.h +++ b/cartographer_grpc/framework/server.h @@ -24,6 +24,7 @@ #include "cartographer/common/make_unique.h" #include "cartographer_grpc/framework/completion_queue_thread.h" +#include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/rpc_handler.h" #include "cartographer_grpc/framework/service.h" #include "grpc++/grpc++.h" @@ -57,10 +58,11 @@ class Server { RpcHandlerInfo{ RpcHandlerType::RequestType::default_instance().GetDescriptor(), RpcHandlerType::ResponseType::default_instance().GetDescriptor(), - [](Rpc* const rpc) { + [](Rpc* const rpc, ExecutionContext* const execution_context) { std::unique_ptr rpc_handler = cartographer::common::make_unique(); rpc_handler->SetRpc(rpc); + rpc_handler->SetExecutionContext(execution_context); return rpc_handler; }, RpcType execution_context); + private: Server(const Options& options); Server(const Server&) = delete; @@ -104,6 +109,10 @@ class Server { // Map of service names to services. std::map services_; + + // A context object that is shared between all implementations of + // 'RpcHandler'. + std::unique_ptr execution_context_; }; } // namespace framework diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc index 02ad2a1..2c77ec2 100644 --- a/cartographer_grpc/framework/server_test.cc +++ b/cartographer_grpc/framework/server_test.cc @@ -16,6 +16,7 @@ #include "cartographer_grpc/framework/server.h" +#include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/proto/math_service.grpc.pb.h" #include "cartographer_grpc/framework/proto/math_service.pb.h" #include "cartographer_grpc/framework/rpc_handler.h" @@ -27,10 +28,16 @@ namespace cartographer_grpc { namespace framework { namespace { +class MathServerContext : public ExecutionContext { + public: + int additional_increment() { return 10; } +}; + class GetServerOptionsHandler : public RpcHandler, proto::GetSumResponse> { public: void OnRequest(const proto::GetSumRequest& request) override { + sum_ += GetContext()->additional_increment(); sum_ += request.input(); } @@ -70,6 +77,8 @@ TEST_F(ServerTest, StartAndStopServerTest) { } TEST_F(ServerTest, ProcessRpcStreamTest) { + server_->SetExecutionContext( + cartographer::common::make_unique()); server_->Start(); auto channel = @@ -87,7 +96,7 @@ TEST_F(ServerTest, ProcessRpcStreamTest) { writer->WritesDone(); grpc::Status status = writer->Finish(); EXPECT_TRUE(status.ok()); - EXPECT_EQ(result.output(), 3); + EXPECT_EQ(result.output(), 33); server_->Shutdown(); } diff --git a/cartographer_grpc/framework/service.cc b/cartographer_grpc/framework/service.cc index 543c01e..76e0d0c 100644 --- a/cartographer_grpc/framework/service.cc +++ b/cartographer_grpc/framework/service.cc @@ -37,12 +37,13 @@ Service::Service(const std::string& service_name, } void Service::StartServing( - std::vector& completion_queue_threads) { + std::vector& completion_queue_threads, + ExecutionContext* execution_context) { int i = 0; for (const auto& rpc_handler_info : rpc_handler_infos_) { for (auto& completion_queue_thread : completion_queue_threads) { Rpc* rpc = active_rpcs_.Add(cartographer::common::make_unique( - i, completion_queue_thread.completion_queue(), + i, completion_queue_thread.completion_queue(), execution_context, rpc_handler_info.second, this)); rpc->RequestNextMethodInvocation(); } diff --git a/cartographer_grpc/framework/service.h b/cartographer_grpc/framework/service.h index 00c9909..e54af0c 100644 --- a/cartographer_grpc/framework/service.h +++ b/cartographer_grpc/framework/service.h @@ -18,6 +18,7 @@ #define CARTOGRAPHER_GRPC_FRAMEWORK_SERVICE_H #include "cartographer_grpc/framework/completion_queue_thread.h" +#include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/rpc.h" #include "cartographer_grpc/framework/rpc_handler.h" #include "grpc++/impl/codegen/service_type.h" @@ -35,7 +36,8 @@ class Service : public ::grpc::Service { Service(const std::string& service_name, const std::map& rpc_handlers); - void StartServing(std::vector& completion_queues); + void StartServing(std::vector& completion_queues, + ExecutionContext* execution_context); void HandleEvent(Rpc::Event event, Rpc* rpc, bool ok); void StopServing();