diff --git a/google/cloud/storage/internal/async/writer_connection_buffered.cc b/google/cloud/storage/internal/async/writer_connection_buffered.cc index 786a33b3b89cd..ec179e207b516 100644 --- a/google/cloud/storage/internal/async/writer_connection_buffered.cc +++ b/google/cloud/storage/internal/async/writer_connection_buffered.cc @@ -281,7 +281,8 @@ class AsyncWriterConnectionBufferedState return tmp; } - void OnQuery(std::unique_lock lk, std::int64_t persisted_size) { + void OnQuery(std::unique_lock lk, std::int64_t persisted_size, + bool is_resume = false) { if (persisted_size < buffer_offset_) { auto id = UploadId(lk); return SetError(std::move(lk), @@ -297,7 +298,22 @@ class AsyncWriterConnectionBufferedState } resend_buffer_.RemovePrefix(static_cast(n)); buffer_offset_ = persisted_size; - write_offset_ -= static_cast(n); + if (is_resume) { + // Since the buffer has been modified to start exactly at the point of the + // resume, the next write on this new stream should start from the + // beginning of this truncated buffer. + write_offset_ = 0; + } else { + // While rare, it is possible that n >= write_offset_ (i.e. the server has + // persisted more than we have sent) if, for example, multiple clients + // resume the same upload. If that is the case, all the bytes covered by + // write_offset_ have been flushed and we can reset it to 0. + if (static_cast(n) >= write_offset_) { + write_offset_ = 0; + } else { + write_offset_ -= static_cast(n); + } + } // If the buffer is small enough, collect all the handlers to notify them. auto const handlers = ClearHandlersIfEmpty(lk); // SetFlushed will release the lock before returning. @@ -382,7 +398,7 @@ class AsyncWriterConnectionBufferedState std::move(state))); } // Regular resume succeeded, object not finalized. Continue writing. - OnQuery(std::move(lk), absl::get(state)); + OnQuery(std::move(lk), absl::get(state), /*is_resume=*/true); } void SetFinalized(std::unique_lock lk, diff --git a/google/cloud/storage/internal/async/writer_connection_buffered_test.cc b/google/cloud/storage/internal/async/writer_connection_buffered_test.cc index 40758de3c17e6..ed01b4581a3f0 100644 --- a/google/cloud/storage/internal/async/writer_connection_buffered_test.cc +++ b/google/cloud/storage/internal/async/writer_connection_buffered_test.cc @@ -1266,6 +1266,67 @@ TEST(WriteConnectionBuffered, SetFinalizedIsIdempotent) { next.first.set_value(true); } +TEST(WriteConnectionBuffered, ResetWriteOffsetOnResume) { + AsyncSequencer sequencer; + auto mock = std::make_unique(); + auto* mock_ptr = mock.get(); + + EXPECT_CALL(*mock_ptr, UploadId).WillRepeatedly(Return("test-upload-id")); + EXPECT_CALL(*mock_ptr, PersistedState) + .WillOnce( + Return(MakePersistedState(0))); // Initial state: 0 bytes persisted. + + EXPECT_CALL(*mock_ptr, Write).WillOnce([&](auto) { + return sequencer.PushBack("Write").then([](auto f) { + if (!f.get()) return TransientError(); // This write will fail. + return Status{}; + }); + }); + + MockFactory mock_factory; + auto resumed_mock = std::make_unique(); + auto* resumed_mock_ptr = resumed_mock.get(); + + EXPECT_CALL(mock_factory, Call).WillOnce([&]() { + return sequencer.PushBack("Resume").then([&](auto) { + // The resumed connection reports that 1024 bytes have been persisted. + EXPECT_CALL(*resumed_mock_ptr, PersistedState) + .WillRepeatedly(Return(MakePersistedState(1024))); + // We expect the next write on the resumed stream to send the remaining + // 1024 bytes. If the write offset was not reset to 0, this size would be + // incorrect. + EXPECT_CALL(*resumed_mock_ptr, Write).WillOnce([&](auto payload) { + EXPECT_EQ(payload.size(), 1024); + return sequencer.PushBack("ResumedWrite").then([](auto) { + return Status{}; + }); + }); + return make_status_or(std::unique_ptr( + std::move(resumed_mock))); + }); + }); + + auto connection = MakeWriterConnectionBuffered( + mock_factory.AsStdFunction(), std::move(mock), TestOptions()); + + // Write a total of 2048 bytes. + auto write = connection->Write(TestPayload(2048)); + + auto next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "Write"); + next.first.set_value(false); + + next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "Resume"); + next.first.set_value(true); + + next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "ResumedWrite"); + next.first.set_value(true); + + EXPECT_STATUS_OK(write.get()); +} + } // namespace GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END } // namespace storage_internal diff --git a/google/cloud/storage/internal/async/writer_connection_resumed.cc b/google/cloud/storage/internal/async/writer_connection_resumed.cc index 3b860bef02a6e..5ae78d307bd59 100644 --- a/google/cloud/storage/internal/async/writer_connection_resumed.cc +++ b/google/cloud/storage/internal/async/writer_connection_resumed.cc @@ -317,7 +317,22 @@ class AsyncWriterConnectionResumedState } resend_buffer_.RemovePrefix(static_cast(n)); buffer_offset_ = persisted_size; - write_offset_ -= static_cast(n); + if (state_ == State::kResuming) { + // Since the buffer has been modified to start exactly at the point of the + // resume, the next write on this new stream should start from the + // beginning of this truncated buffer. + write_offset_ = 0; + } else { + // While rare, it is possible that n >= write_offset_ (i.e. the server has + // persisted more than we have sent) if, for example, multiple clients + // resume the same upload. If that is the case, all the bytes covered by + // write_offset_ have been flushed and we can reset it to 0. + if (static_cast(n) >= write_offset_) { + write_offset_ = 0; + } else { + write_offset_ -= static_cast(n); + } + } // If the buffer is small enough, collect all the handlers to notify them. auto const handlers = ClearHandlersIfEmpty(lk); state_ = State::kIdle; diff --git a/google/cloud/storage/internal/async/writer_connection_resumed_test.cc b/google/cloud/storage/internal/async/writer_connection_resumed_test.cc index c274b683c4287..dcebb80ac8c63 100644 --- a/google/cloud/storage/internal/async/writer_connection_resumed_test.cc +++ b/google/cloud/storage/internal/async/writer_connection_resumed_test.cc @@ -11,9 +11,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "google/cloud/storage/internal/async/writer_connection_resumed.h" +#include "google/cloud/mocks/mock_async_streaming_read_write_rpc.h" #include "google/cloud/storage/async/connection.h" +#include "google/cloud/storage/internal/grpc/ctype_cord_workaround.h" #include "google/cloud/storage/mocks/mock_async_writer_connection.h" #include "google/cloud/storage/testing/canonical_errors.h" #include "google/cloud/storage/testing/mock_hash_function.h" @@ -615,6 +616,123 @@ TEST(WriterConnectionResumed, OnQueryUpdatesWriteHandle) { EXPECT_EQ(current_handle->handle(), "updated-handle"); } +TEST(WriterConnectionResumed, ResetWriteOffsetOnResume) { + AsyncSequencer sequencer; + auto mock = std::make_unique(); + auto* mock_ptr = mock.get(); + + auto initial_request = google::storage::v2::BidiWriteObjectRequest{}; + google::storage::v2::BidiWriteObjectResponse first_response; + first_response.mutable_write_handle()->set_handle("initial-handle"); + + auto mock_hash = + std::make_shared(); + EXPECT_CALL(*mock_hash, Update(::testing::An(), + ::testing::An(), + ::testing::An())) + .WillRepeatedly(Return(Status())); + + EXPECT_CALL(*mock_ptr, PersistedState) + .WillOnce(Return(MakePersistedState(0))) + .WillOnce(Return(MakePersistedState(1024))); + + auto const payload = TestPayload(2048); + + EXPECT_CALL(*mock_ptr, Flush(_)).WillOnce([&](auto) { + return sequencer.PushBack("Flush").then([](auto f) { + if (f.get()) return Status{}; + return TransientError(); + }); + }); + + MockFactory mock_factory; + auto mock_stream = + std::make_unique>(); + auto* mock_stream_ptr = mock_stream.get(); + + EXPECT_CALL(mock_factory, Call(_)) + .WillOnce([&](google::storage::v2::BidiWriteObjectRequest const&) { + WriteObject::WriteResult result; + result.stream = std::move(mock_stream); + result.first_response.mutable_write_handle()->set_handle("new-handle"); + return sequencer.PushBack("Factory").then( + [r = std::move(result)](auto) mutable { + return StatusOr(std::move(r)); + }); + }); + + EXPECT_CALL(*mock_stream_ptr, Write(_, _)) + .WillOnce([&](google::storage::v2::BidiWriteObjectRequest const& request, + grpc::WriteOptions) { + EXPECT_EQ(GetContent(request.checksummed_data()).size(), 1024); + EXPECT_EQ(GetContent(request.checksummed_data()), + std::string(1024, 'A')); + return sequencer.PushBack("StreamWrite").then([](auto) { + return true; + }); + }) + .WillOnce([&](google::storage::v2::BidiWriteObjectRequest const& request, + grpc::WriteOptions) { + EXPECT_TRUE(GetContent(request.checksummed_data()).empty()); + EXPECT_TRUE(request.flush()); + return sequencer.PushBack("GhostWrite").then([](auto) { return true; }); + }); + + google::storage::v2::BidiWriteObjectResponse read_response1; + read_response1.set_persisted_size(2048); + google::storage::v2::BidiWriteObjectResponse read_response2; + read_response2.set_persisted_size(2048); + EXPECT_CALL(*mock_stream_ptr, Read) + .WillOnce([&, read_response1]() { + return sequencer.PushBack("StreamRead1").then([read_response1](auto) { + return absl::make_optional(read_response1); + }); + }) + .WillOnce([&, read_response2]() { + return sequencer.PushBack("StreamRead2").then([read_response2](auto) { + return absl::make_optional(read_response2); + }); + }); + + EXPECT_CALL(*mock_stream_ptr, Finish) + .WillOnce(Return(make_ready_future(Status{}))); + EXPECT_CALL(*mock_stream_ptr, Cancel).WillRepeatedly(Return()); + + auto connection = MakeWriterConnectionResumed( + mock_factory.AsStdFunction(), std::move(mock), initial_request, mock_hash, + first_response, Options{}); + + auto write = connection->Write(payload); + + auto next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "Flush"); + next.first.set_value(false); + + next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "Factory"); + next.first.set_value(true); + + next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "StreamWrite"); + next.first.set_value(true); + + next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "StreamRead1"); + next.first.set_value(true); + + next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "GhostWrite"); + next.first.set_value(true); + + next = sequencer.PopFrontWithName(); + EXPECT_EQ(next.second, "StreamRead2"); + next.first.set_value(true); + + EXPECT_THAT(write.get(), StatusIs(StatusCode::kOk)); +} + } // namespace GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END } // namespace storage_internal