Skip to content

Commit bff1d69

Browse files
authored
Fixing memoryview error (#2929)
* Fixed duplicate seq-0 chunk bug * Fixed formatting errors
1 parent c278aff commit bff1d69

File tree

2 files changed

+38
-31
lines changed

2 files changed

+38
-31
lines changed

nvflare/fuel/f3/streaming/blob_streamer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def __init__(self, future: StreamFuture, stream: Stream):
7777
else:
7878
self.buffer = FastBuffer()
7979

80+
def __str__(self):
81+
return f"Blob[SID:{self.future.get_stream_id()} Size:{self.size}]"
82+
8083

8184
class BlobHandler:
8285
def __init__(self, blob_cb: Callable):
@@ -113,23 +116,22 @@ def _read_stream(blob_task: BlobTask):
113116
if blob_task.pre_allocated:
114117
remaining = len(blob_task.buffer) - buf_size
115118
if length > remaining:
116-
log.error(f"Buffer overrun: {remaining=} {length=} {buf_size=}")
119+
log.error(f"{blob_task} Buffer overrun: {remaining=} {length=} {buf_size=}")
117120
if remaining > 0:
118121
blob_task.buffer[buf_size : buf_size + remaining] = buf[0:remaining]
122+
break
119123
else:
120124
blob_task.buffer[buf_size : buf_size + length] = buf
121125
else:
122126
blob_task.buffer.append(buf)
123127
except Exception as ex:
124-
log.error(f"memory view error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}")
128+
log.error(f"{blob_task} memoryview error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}")
125129
raise ex
126130

127131
buf_size += length
128132

129133
if blob_task.size and blob_task.size != buf_size:
130-
log.warning(
131-
f"Stream {blob_task.future.get_stream_id()} size doesn't match: " f"{blob_task.size} <> {buf_size}"
132-
)
134+
log.warning(f"Stream {blob_task} Size doesn't match: " f"{blob_task.size} <> {buf_size}")
133135

134136
if blob_task.pre_allocated:
135137
result = blob_task.buffer
@@ -138,7 +140,7 @@ def _read_stream(blob_task: BlobTask):
138140

139141
blob_task.future.set_result(result)
140142
except Exception as ex:
141-
log.error(f"Stream {blob_task.future.get_stream_id()} read error: {ex}")
143+
log.error(f"Stream {blob_task} Read error: {ex}")
142144
log.error(secure_format_traceback())
143145
blob_task.future.set_exception(ex)
144146

nvflare/fuel/f3/streaming/byte_receiver.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(self, sid: int, origin: str):
7171
self.last_chunk_received = False
7272

7373
def __str__(self):
74-
return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic}]"
74+
return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]"
7575

7676

7777
class RxStream(Stream):
@@ -98,9 +98,7 @@ def read(self, chunk_size: int) -> bytes:
9898

9999
# Block if buffers are empty
100100
if count > 0:
101-
log.warning(f"Read block is unblocked multiple times: {count}")
102-
103-
self.task.waiter.clear()
101+
log.warning(f"{self.task} Read block is unblocked multiple times: {count}")
104102

105103
if not self.task.waiter.wait(self.timeout):
106104
error = StreamError(f"{self.task} read timed out after {self.timeout} seconds")
@@ -117,6 +115,7 @@ def _read_chunk(self, chunk_size: int) -> Tuple[int, Optional[BytesAlike]]:
117115
if self.task.eos:
118116
return RESULT_EOS, None
119117
else:
118+
self.task.waiter.clear()
120119
return RESULT_WAIT, None
121120

122121
last_chunk, buf = self.task.buffers.popleft()
@@ -239,33 +238,39 @@ def _data_handler(self, message: Message):
239238
self.stop_task(task, StreamError(f"Received error from {origin}: {error}"), notify=False)
240239
return
241240

242-
if seq == 0:
243-
# Handle new stream
244-
task.channel = message.get_header(StreamHeaderKey.CHANNEL)
245-
task.topic = message.get_header(StreamHeaderKey.TOPIC)
246-
task.headers = message.headers
241+
with task.task_lock:
242+
if seq == 0:
243+
# Handle new stream
244+
task.channel = message.get_header(StreamHeaderKey.CHANNEL)
245+
task.topic = message.get_header(StreamHeaderKey.TOPIC)
246+
task.headers = message.headers
247+
248+
# GRPC may re-send the same request, causing seq 0 delivered more than once
249+
if task.stream_future:
250+
log.warning(f"{task} Received duplicate chunk 0, ignored")
251+
return
247252

248-
task.stream_future = StreamFuture(sid, message.headers)
249-
task.size = message.get_header(StreamHeaderKey.SIZE, 0)
250-
task.stream_future.set_size(task.size)
253+
task.stream_future = StreamFuture(sid, message.headers)
254+
task.size = message.get_header(StreamHeaderKey.SIZE, 0)
255+
task.stream_future.set_size(task.size)
251256

252-
# Invoke callback
253-
callback = self.registry.find(task.channel, task.topic)
254-
if not callback:
255-
self.stop_task(task, StreamError(f"No callback is registered for {task.channel}/{task.topic}"))
256-
return
257+
# Invoke callback
258+
callback = self.registry.find(task.channel, task.topic)
259+
if not callback:
260+
self.stop_task(task, StreamError(f"No callback is registered for {task.channel}/{task.topic}"))
261+
return
257262

258-
self.received_stream_counter_pool.increment(
259-
category=stream_stats_category(task.channel, task.topic, "stream"), counter_name=COUNTER_NAME_RECEIVED
260-
)
263+
self.received_stream_counter_pool.increment(
264+
category=stream_stats_category(task.channel, task.topic, "stream"),
265+
counter_name=COUNTER_NAME_RECEIVED,
266+
)
261267

262-
self.received_stream_size_pool.record_value(
263-
category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB
264-
)
268+
self.received_stream_size_pool.record_value(
269+
category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB
270+
)
265271

266-
stream_thread_pool.submit(self._callback_wrapper, task, callback)
272+
stream_thread_pool.submit(self._callback_wrapper, task, callback)
267273

268-
with task.task_lock:
269274
data_type = message.get_header(StreamHeaderKey.DATA_TYPE)
270275
last_chunk = data_type == StreamDataType.FINAL
271276
if last_chunk:

0 commit comments

Comments
 (0)