Skip to content

Commit 8271301

Browse files
jaycee-licopybara-github
authored andcommitted
fix: Lightning trainer fails to be unwrapped in remote training
PiperOrigin-RevId: 566779800
1 parent a83105e commit 8271301

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

vertexai/preview/_workflow/driver/__init__.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -197,18 +197,33 @@ def _unwrapper(instance: Any) -> Callable[..., Any]:
197197

198198
config_map = dict()
199199

200-
for attr_name, attr_value in inspect.getmembers(instance):
201-
if isinstance(attr_value, VertexRemoteFunctor):
202-
config_map[attr_name] = (
203-
attr_value.vertex,
204-
attr_value._remote_executor,
205-
attr_value._remote_executor_kwargs,
206-
)
207-
setattr(instance, attr_name, attr_value._method)
208-
209200
if not wrapped_in_place:
201+
for (
202+
attr_name,
203+
attr_value,
204+
remote_executor,
205+
remote_executor_kwargs,
206+
) in _supported_member_iter(instance):
207+
if isinstance(attr_value, VertexRemoteFunctor):
208+
config_map[attr_name] = (
209+
attr_value.vertex,
210+
remote_executor,
211+
remote_executor_kwargs,
212+
)
213+
setattr(instance, attr_name, attr_value._method)
214+
210215
instance.__class__ = super_class
211216

217+
else:
218+
for attr_name, attr_value in inspect.getmembers(instance):
219+
if isinstance(attr_value, VertexRemoteFunctor):
220+
config_map[attr_name] = (
221+
attr_value.vertex,
222+
attr_value._remote_executor,
223+
attr_value._remote_executor_kwargs,
224+
)
225+
setattr(instance, attr_name, attr_value._method)
226+
212227
return functools.partial(
213228
_rewrapper, wrapped_class=current_class, config_map=config_map
214229
)

0 commit comments

Comments
 (0)