diff --git a/cartographer_grpc/framework/client.h b/cartographer_grpc/framework/client.h new file mode 100644 index 0000000..a485f98 --- /dev/null +++ b/cartographer_grpc/framework/client.h @@ -0,0 +1,171 @@ +/* + * Copyright 2018 The Cartographer Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H +#define CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H + +#include "grpc++/grpc++.h" +#include "grpc++/impl/codegen/client_unary_call.h" +#include "grpc++/impl/codegen/sync_stream.h" + +namespace cartographer_grpc { +namespace framework { + +template +class Client { + public: + Client(std::shared_ptr channel) + : channel_(channel), + rpc_method_name_( + RpcHandlerInterface::Instantiate()->method_name()), + rpc_method_(rpc_method_name_.c_str(), + RpcType::value, + channel_) {} + + bool Read(typename RpcHandlerType::ResponseType* response) { + switch (rpc_method_.method_type()) { + case grpc::internal::RpcMethod::BIDI_STREAMING: + InstantiateClientReaderWriterIfNeeded(); + return client_reader_writer_->Read(response); + case grpc::internal::RpcMethod::SERVER_STREAMING: + CHECK(client_reader_); + return client_reader_->Read(response); + default: + LOG(FATAL) << "Not implemented."; + } + } + + bool Write(const typename RpcHandlerType::RequestType& request) { + switch (rpc_method_.method_type()) { + case grpc::internal::RpcMethod::NORMAL_RPC: + return MakeBlockingUnaryCall(request, &response_).ok(); + case grpc::internal::RpcMethod::CLIENT_STREAMING: + InstantiateClientWriterIfNeeded(); + return client_writer_->Write(request); + case grpc::internal::RpcMethod::BIDI_STREAMING: + InstantiateClientReaderWriterIfNeeded(); + return client_reader_writer_->Write(request); + case grpc::internal::RpcMethod::SERVER_STREAMING: + InstantiateClientReader(request); + return true; + } + LOG(FATAL) << "Not reached."; + } + + bool WritesDone() { + switch (rpc_method_.method_type()) { + case grpc::internal::RpcMethod::CLIENT_STREAMING: + InstantiateClientWriterIfNeeded(); + return client_writer_->WritesDone(); + case grpc::internal::RpcMethod::BIDI_STREAMING: + InstantiateClientReaderWriterIfNeeded(); + return client_reader_writer_->WritesDone(); + default: + LOG(FATAL) << "Not implemented."; + } + } + + grpc::Status Finish() { + switch (rpc_method_.method_type()) { + case grpc::internal::RpcMethod::CLIENT_STREAMING: + InstantiateClientWriterIfNeeded(); + return client_writer_->Finish(); + case grpc::internal::RpcMethod::BIDI_STREAMING: + InstantiateClientReaderWriterIfNeeded(); + return client_reader_writer_->Finish(); + case grpc::internal::RpcMethod::SERVER_STREAMING: + CHECK(client_reader_); + return client_reader_->Finish(); + default: + LOG(FATAL) << "Not implemented."; + } + } + + const typename RpcHandlerType::ResponseType& response() { + CHECK(rpc_method_.method_type() == grpc::internal::RpcMethod::NORMAL_RPC || + rpc_method_.method_type() == + grpc::internal::RpcMethod::CLIENT_STREAMING); + return response_; + } + + private: + void InstantiateClientWriterIfNeeded() { + CHECK_EQ(rpc_method_.method_type(), + grpc::internal::RpcMethod::CLIENT_STREAMING); + if (!client_writer_) { + client_writer_.reset( + grpc::internal::ClientWriterFactory< + typename RpcHandlerType::RequestType>::Create(channel_.get(), + rpc_method_, + &client_context_, + &response_)); + } + } + + void InstantiateClientReaderWriterIfNeeded() { + CHECK_EQ(rpc_method_.method_type(), + grpc::internal::RpcMethod::BIDI_STREAMING); + if (!client_reader_writer_) { + client_reader_writer_.reset( + grpc::internal::ClientReaderWriterFactory< + typename RpcHandlerType::RequestType, + typename RpcHandlerType::ResponseType>::Create(channel_.get(), + rpc_method_, + &client_context_)); + } + } + + void InstantiateClientReader( + const typename RpcHandlerType::RequestType& request) { + CHECK_EQ(rpc_method_.method_type(), + grpc::internal::RpcMethod::SERVER_STREAMING); + client_reader_.reset( + grpc::internal::ClientReaderFactory< + typename RpcHandlerType::ResponseType>::Create(channel_.get(), + rpc_method_, + &client_context_, + request)); + } + + grpc::Status MakeBlockingUnaryCall( + const typename RpcHandlerType::RequestType& request, + typename RpcHandlerType::ResponseType* response) { + CHECK_EQ(rpc_method_.method_type(), grpc::internal::RpcMethod::NORMAL_RPC); + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpc_method_, &client_context_, request, response); + } + + std::shared_ptr channel_; + grpc::ClientContext client_context_; + const std::string rpc_method_name_; + const ::grpc::internal::RpcMethod rpc_method_; + + std::unique_ptr> + client_writer_; + std::unique_ptr< + grpc::ClientReaderWriter> + client_reader_writer_; + std::unique_ptr> + client_reader_; + typename RpcHandlerType::ResponseType response_; +}; + +} // namespace framework +} // namespace cartographer_grpc + +#endif // CARTOGRAPHER_GRPC_FRAMEWORK_CLIENT_H diff --git a/cartographer_grpc/framework/rpc_handler_interface.h b/cartographer_grpc/framework/rpc_handler_interface.h index 368a16e..1c6a5ff 100644 --- a/cartographer_grpc/framework/rpc_handler_interface.h +++ b/cartographer_grpc/framework/rpc_handler_interface.h @@ -17,6 +17,7 @@ #ifndef CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H #define CARTOGRAPHER_GRPC_FRAMEWORK_RPC_HANDLER_INTERFACE_H_H +#include "cartographer/common/make_unique.h" #include "cartographer_grpc/framework/execution_context.h" #include "google/protobuf/message.h" #include "grpc++/grpc++.h" @@ -40,6 +41,10 @@ class RpcHandlerInterface { const ::google::protobuf::Message* request) = 0; virtual void OnReadsDone(){}; virtual void OnFinish(){}; + template + static std::unique_ptr Instantiate() { + return cartographer::common::make_unique(); + } }; using RpcHandlerFactory = std::function( diff --git a/cartographer_grpc/framework/server.h b/cartographer_grpc/framework/server.h index f44bd4d..647349a 100644 --- a/cartographer_grpc/framework/server.h +++ b/cartographer_grpc/framework/server.h @@ -57,7 +57,8 @@ class Server { template void RegisterHandler() { - std::string method_full_name = GetMethodFullName(); + std::string method_full_name = + RpcHandlerInterface::Instantiate()->method_name(); std::string service_full_name; std::string method_name; std::tie(service_full_name, method_name) = @@ -83,11 +84,6 @@ class Server { private: using ServiceInfo = std::map; - template - std::string GetMethodFullName() { - auto handler = cartographer::common::make_unique(); - return handler->method_name(); - } std::tuple ParseMethodFullName(const std::string& method_full_name); diff --git a/cartographer_grpc/framework/server_test.cc b/cartographer_grpc/framework/server_test.cc index 5c5e6bc..5a97b59 100644 --- a/cartographer_grpc/framework/server_test.cc +++ b/cartographer_grpc/framework/server_test.cc @@ -18,6 +18,7 @@ #include +#include "cartographer_grpc/framework/client.h" #include "cartographer_grpc/framework/execution_context.h" #include "cartographer_grpc/framework/proto/math_service.grpc.pb.h" #include "cartographer_grpc/framework/proto/math_service.pb.h" @@ -156,13 +157,10 @@ class ServerTest : public ::testing::Test { client_channel_ = grpc::CreateChannel(kServerAddress, grpc::InsecureChannelCredentials()); - stub_ = proto::Math::NewStub(client_channel_); } std::unique_ptr server_; std::shared_ptr client_channel_; - std::unique_ptr stub_; - grpc::ClientContext client_context_; }; TEST_F(ServerTest, StartAndStopServerTest) { @@ -175,18 +173,15 @@ TEST_F(ServerTest, ProcessRpcStreamTest) { cartographer::common::make_unique()); server_->Start(); - proto::GetSumResponse result; - std::unique_ptr> writer( - stub_->GetSum(&client_context_, &result)); + Client client(client_channel_); for (int i = 0; i < 3; ++i) { proto::GetSumRequest request; request.set_input(i); - EXPECT_TRUE(writer->Write(request)); + EXPECT_TRUE(client.Write(request)); } - writer->WritesDone(); - grpc::Status status = writer->Finish(); - EXPECT_TRUE(status.ok()); - EXPECT_EQ(result.output(), 33); + EXPECT_TRUE(client.WritesDone()); + EXPECT_TRUE(client.Finish().ok()); + EXPECT_EQ(client.response().output(), 33); server_->Shutdown(); } @@ -194,12 +189,11 @@ TEST_F(ServerTest, ProcessRpcStreamTest) { TEST_F(ServerTest, ProcessUnaryRpcTest) { server_->Start(); - proto::GetSquareResponse result; + Client client(client_channel_); proto::GetSquareRequest request; request.set_input(11); - grpc::Status status = stub_->GetSquare(&client_context_, request, &result); - EXPECT_TRUE(status.ok()); - EXPECT_EQ(result.output(), 121); + EXPECT_TRUE(client.Write(request)); + EXPECT_EQ(client.response().output(), 121); server_->Shutdown(); } @@ -207,22 +201,21 @@ TEST_F(ServerTest, ProcessUnaryRpcTest) { TEST_F(ServerTest, ProcessBidiStreamingRpcTest) { server_->Start(); - auto reader_writer = stub_->GetRunningSum(&client_context_); + Client client(client_channel_); for (int i = 0; i < 3; ++i) { proto::GetSumRequest request; request.set_input(i); - EXPECT_TRUE(reader_writer->Write(request)); + EXPECT_TRUE(client.Write(request)); } - reader_writer->WritesDone(); + client.WritesDone(); proto::GetSumResponse response; - std::list expected_responses = {0, 0, 1, 1, 3, 3}; - while (reader_writer->Read(&response)) { + while (client.Read(&response)) { EXPECT_EQ(expected_responses.front(), response.output()); expected_responses.pop_front(); } EXPECT_TRUE(expected_responses.empty()); - EXPECT_TRUE(reader_writer->Finish().ok()); + EXPECT_TRUE(client.Finish().ok()); server_->Shutdown(); } @@ -232,10 +225,6 @@ TEST_F(ServerTest, WriteFromOtherThread) { cartographer::common::make_unique()); server_->Start(); - proto::GetEchoResponse result; - proto::GetEchoRequest request; - request.set_input(13); - Server* server = server_.get(); std::thread response_thread([server]() { std::future responder_future = @@ -245,10 +234,12 @@ TEST_F(ServerTest, WriteFromOtherThread) { CHECK(responder()); }); - grpc::Status status = stub_->GetEcho(&client_context_, request, &result); + Client client(client_channel_); + proto::GetEchoRequest request; + request.set_input(13); + EXPECT_TRUE(client.Write(request)); response_thread.join(); - EXPECT_TRUE(status.ok()); - EXPECT_EQ(result.output(), 13); + EXPECT_EQ(client.response().output(), 13); server_->Shutdown(); } @@ -256,17 +247,18 @@ TEST_F(ServerTest, WriteFromOtherThread) { TEST_F(ServerTest, ProcessServerStreamingRpcTest) { server_->Start(); + Client client(client_channel_); proto::GetSequenceRequest request; request.set_input(12); - auto reader = stub_->GetSequence(&client_context_, request); + client.Write(request); proto::GetSequenceResponse response; for (int i = 0; i < 12; ++i) { - EXPECT_TRUE(reader->Read(&response)); + EXPECT_TRUE(client.Read(&response)); EXPECT_EQ(response.output(), i); } - EXPECT_FALSE(reader->Read(&response)); - EXPECT_TRUE(reader->Finish().ok()); + EXPECT_FALSE(client.Read(&response)); + EXPECT_TRUE(client.Finish().ok()); server_->Shutdown(); }