Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ class AsyncWriterConnectionBufferedState
return tmp;
}

void OnQuery(std::unique_lock<std::mutex> lk, std::int64_t persisted_size) {
void OnQuery(std::unique_lock<std::mutex> lk, std::int64_t persisted_size,
Comment thread
kalragauri marked this conversation as resolved.
bool is_resume = false) {
if (persisted_size < buffer_offset_) {
auto id = UploadId(lk);
return SetError(std::move(lk),
Expand All @@ -297,7 +298,22 @@ class AsyncWriterConnectionBufferedState
}
resend_buffer_.RemovePrefix(static_cast<std::size_t>(n));
buffer_offset_ = persisted_size;
write_offset_ -= static_cast<std::size_t>(n);
if (is_resume) {
// Since the buffer has been modified to start exactly at the point of the
// resume, the next write on this new stream should start from the
// beginning of this truncated buffer.
write_offset_ = 0;
} else {
// While rare, it is possible that n >= write_offset_ (i.e. the server has
// persisted more than we have sent) if, for example, multiple clients
// resume the same upload. If that is the case, all the bytes covered by
// write_offset_ have been flushed and we can reset it to 0.
if (static_cast<std::size_t>(n) >= write_offset_) {
write_offset_ = 0;
} else {
write_offset_ -= static_cast<std::size_t>(n);
}
}
Comment thread
kalragauri marked this conversation as resolved.
// If the buffer is small enough, collect all the handlers to notify them.
auto const handlers = ClearHandlersIfEmpty(lk);
// SetFlushed will release the lock before returning.
Expand Down Expand Up @@ -382,7 +398,7 @@ class AsyncWriterConnectionBufferedState
std::move(state)));
}
// Regular resume succeeded, object not finalized. Continue writing.
OnQuery(std::move(lk), absl::get<std::int64_t>(state));
OnQuery(std::move(lk), absl::get<std::int64_t>(state), /*is_resume=*/true);
}

void SetFinalized(std::unique_lock<std::mutex> lk,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,67 @@ TEST(WriteConnectionBuffered, SetFinalizedIsIdempotent) {
next.first.set_value(true);
}

TEST(WriteConnectionBuffered, ResetWriteOffsetOnResume) {
AsyncSequencer<bool> sequencer;
auto mock = std::make_unique<MockAsyncWriterConnection>();
auto* mock_ptr = mock.get();

EXPECT_CALL(*mock_ptr, UploadId).WillRepeatedly(Return("test-upload-id"));
EXPECT_CALL(*mock_ptr, PersistedState)
.WillOnce(
Return(MakePersistedState(0))); // Initial state: 0 bytes persisted.

EXPECT_CALL(*mock_ptr, Write).WillOnce([&](auto) {
return sequencer.PushBack("Write").then([](auto f) {
if (!f.get()) return TransientError(); // This write will fail.
return Status{};
});
});

MockFactory mock_factory;
auto resumed_mock = std::make_unique<MockAsyncWriterConnection>();
auto* resumed_mock_ptr = resumed_mock.get();

EXPECT_CALL(mock_factory, Call).WillOnce([&]() {
return sequencer.PushBack("Resume").then([&](auto) {
// The resumed connection reports that 1024 bytes have been persisted.
EXPECT_CALL(*resumed_mock_ptr, PersistedState)
.WillRepeatedly(Return(MakePersistedState(1024)));
// We expect the next write on the resumed stream to send the remaining
// 1024 bytes. If the write offset was not reset to 0, this size would be
// incorrect.
EXPECT_CALL(*resumed_mock_ptr, Write).WillOnce([&](auto payload) {
EXPECT_EQ(payload.size(), 1024);
return sequencer.PushBack("ResumedWrite").then([](auto) {
return Status{};
});
});
return make_status_or(std::unique_ptr<storage::AsyncWriterConnection>(
std::move(resumed_mock)));
});
});

auto connection = MakeWriterConnectionBuffered(
mock_factory.AsStdFunction(), std::move(mock), TestOptions());

// Write a total of 2048 bytes.
auto write = connection->Write(TestPayload(2048));

auto next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "Write");
next.first.set_value(false);

next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "Resume");
next.first.set_value(true);

next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "ResumedWrite");
next.first.set_value(true);

EXPECT_STATUS_OK(write.get());
}

} // namespace
GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END
} // namespace storage_internal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,22 @@ class AsyncWriterConnectionResumedState
}
resend_buffer_.RemovePrefix(static_cast<std::size_t>(n));
buffer_offset_ = persisted_size;
write_offset_ -= static_cast<std::size_t>(n);
if (state_ == State::kResuming) {
// Since the buffer has been modified to start exactly at the point of the
// resume, the next write on this new stream should start from the
// beginning of this truncated buffer.
write_offset_ = 0;
} else {
// While rare, it is possible that n >= write_offset_ (i.e. the server has
// persisted more than we have sent) if, for example, multiple clients
// resume the same upload. If that is the case, all the bytes covered by
// write_offset_ have been flushed and we can reset it to 0.
if (static_cast<std::size_t>(n) >= write_offset_) {
write_offset_ = 0;
} else {
write_offset_ -= static_cast<std::size_t>(n);
}
}
// If the buffer is small enough, collect all the handlers to notify them.
auto const handlers = ClearHandlersIfEmpty(lk);
state_ = State::kIdle;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "google/cloud/storage/internal/async/writer_connection_resumed.h"
#include "google/cloud/mocks/mock_async_streaming_read_write_rpc.h"
#include "google/cloud/storage/async/connection.h"
#include "google/cloud/storage/internal/grpc/ctype_cord_workaround.h"
#include "google/cloud/storage/mocks/mock_async_writer_connection.h"
#include "google/cloud/storage/testing/canonical_errors.h"
#include "google/cloud/storage/testing/mock_hash_function.h"
Expand Down Expand Up @@ -615,6 +616,123 @@ TEST(WriterConnectionResumed, OnQueryUpdatesWriteHandle) {
EXPECT_EQ(current_handle->handle(), "updated-handle");
}

TEST(WriterConnectionResumed, ResetWriteOffsetOnResume) {
AsyncSequencer<bool> sequencer;
auto mock = std::make_unique<MockAsyncWriterConnection>();
auto* mock_ptr = mock.get();

auto initial_request = google::storage::v2::BidiWriteObjectRequest{};
google::storage::v2::BidiWriteObjectResponse first_response;
first_response.mutable_write_handle()->set_handle("initial-handle");

auto mock_hash =
std::make_shared<google::cloud::storage::testing::MockHashFunction>();
EXPECT_CALL(*mock_hash, Update(::testing::An<std::int64_t>(),
::testing::An<absl::Cord const&>(),
::testing::An<std::uint32_t>()))
.WillRepeatedly(Return(Status()));

EXPECT_CALL(*mock_ptr, PersistedState)
.WillOnce(Return(MakePersistedState(0)))
.WillOnce(Return(MakePersistedState(1024)));

auto const payload = TestPayload(2048);

EXPECT_CALL(*mock_ptr, Flush(_)).WillOnce([&](auto) {
return sequencer.PushBack("Flush").then([](auto f) {
if (f.get()) return Status{};
return TransientError();
});
});

MockFactory mock_factory;
auto mock_stream =
std::make_unique<google::cloud::mocks::MockAsyncStreamingReadWriteRpc<
google::storage::v2::BidiWriteObjectRequest,
google::storage::v2::BidiWriteObjectResponse>>();
auto* mock_stream_ptr = mock_stream.get();

EXPECT_CALL(mock_factory, Call(_))
.WillOnce([&](google::storage::v2::BidiWriteObjectRequest const&) {
WriteObject::WriteResult result;
result.stream = std::move(mock_stream);
result.first_response.mutable_write_handle()->set_handle("new-handle");
return sequencer.PushBack("Factory").then(
[r = std::move(result)](auto) mutable {
return StatusOr<WriteObject::WriteResult>(std::move(r));
});
});

EXPECT_CALL(*mock_stream_ptr, Write(_, _))
.WillOnce([&](google::storage::v2::BidiWriteObjectRequest const& request,
grpc::WriteOptions) {
EXPECT_EQ(GetContent(request.checksummed_data()).size(), 1024);
EXPECT_EQ(GetContent(request.checksummed_data()),
std::string(1024, 'A'));
return sequencer.PushBack("StreamWrite").then([](auto) {
return true;
});
})
.WillOnce([&](google::storage::v2::BidiWriteObjectRequest const& request,
grpc::WriteOptions) {
EXPECT_TRUE(GetContent(request.checksummed_data()).empty());
EXPECT_TRUE(request.flush());
return sequencer.PushBack("GhostWrite").then([](auto) { return true; });
});

google::storage::v2::BidiWriteObjectResponse read_response1;
read_response1.set_persisted_size(2048);
google::storage::v2::BidiWriteObjectResponse read_response2;
read_response2.set_persisted_size(2048);
EXPECT_CALL(*mock_stream_ptr, Read)
.WillOnce([&, read_response1]() {
return sequencer.PushBack("StreamRead1").then([read_response1](auto) {
return absl::make_optional(read_response1);
});
})
.WillOnce([&, read_response2]() {
return sequencer.PushBack("StreamRead2").then([read_response2](auto) {
return absl::make_optional(read_response2);
});
});

EXPECT_CALL(*mock_stream_ptr, Finish)
.WillOnce(Return(make_ready_future(Status{})));
EXPECT_CALL(*mock_stream_ptr, Cancel).WillRepeatedly(Return());

auto connection = MakeWriterConnectionResumed(
mock_factory.AsStdFunction(), std::move(mock), initial_request, mock_hash,
first_response, Options{});

auto write = connection->Write(payload);

auto next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "Flush");
next.first.set_value(false);

next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "Factory");
next.first.set_value(true);

next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "StreamWrite");
next.first.set_value(true);

next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "StreamRead1");
next.first.set_value(true);

next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "GhostWrite");
next.first.set_value(true);

next = sequencer.PopFrontWithName();
EXPECT_EQ(next.second, "StreamRead2");
next.first.set_value(true);

EXPECT_THAT(write.get(), StatusIs(StatusCode::kOk));
}

} // namespace
GOOGLE_CLOUD_CPP_INLINE_NAMESPACE_END
} // namespace storage_internal
Expand Down
Loading