Address review follow-ups for parent-main inheritance opt-out

Clean up mutable defaults, give parent-main bootstrap data a named type, and add direct start_actor coverage so the opt-out change is clearer to review.
subint_spawner_backend
mahmoud 2026-04-06 22:32:50 +00:00 committed by mahmoudhas
parent ea971d25aa
commit 00637764d9
4 changed files with 93 additions and 19 deletions

View File

@ -204,6 +204,53 @@ def test_loglevel_propagated_to_subactor(
assert 'yoyoyo' in captured.err assert 'yoyoyo' in captured.err
def test_run_in_actor_can_skip_parent_main_inheritance(
    start_method,
    reg_addr,
    monkeypatch,
):
    '''
    Verify that `ActorNursery.run_in_actor()` honours
    `inherit_parent_main=False`.

    Two one-shot subactors run `get_main_mod_name` (defined elsewhere
    in this module; presumably it reports the name of the child's
    `__main__` module - confirm): one with the default inheritance
    and one opting out. The default child is expected to report
    `'__mp_main__'`, the opted-out child `'__main__'`.

    Skipped for every spawn backend other than `'trio'`.

    '''
    if start_method != 'trio':
        pytest.skip(
            'parent main inheritance opt-out only affects the trio spawn backend'
        )

    from tractor.spawn import _mp_fixup_main

    def fake_figure_out_main(inherit_parent_main=True):
        # Stand-in for the real helper: advertise this test module as
        # the parent's main unless the caller opted out, in which case
        # no bootstrap data is handed to the child at all.
        if not inherit_parent_main:
            return {}
        return {'init_main_from_name': __name__}

    monkeypatch.setattr(
        _mp_fixup_main,
        '_mp_figure_out_main',
        fake_figure_out_main,
    )

    async def main():
        async with tractor.open_nursery(
            name='registrar',
            start_method=start_method,
            registry_addrs=[reg_addr],
        ) as nursery:
            replaying = await nursery.run_in_actor(
                get_main_mod_name,
                name='replaying-parent-main',
            )
            isolated = await nursery.run_in_actor(
                get_main_mod_name,
                name='isolated-parent-main',
                inherit_parent_main=False,
            )

            # Stdlib spawn re-runs an importable parent ``__main__`` as
            # ``__mp_main__``; opting out should leave the child bootstrap
            # module alone instead.
            # https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods
            replayed_name = await replaying.result()
            assert replayed_name == '__mp_main__'

            isolated_name = await isolated.result()
            assert isolated_name == '__main__'

    trio.run(main)
def test_start_actor_can_skip_parent_main_inheritance( def test_start_actor_can_skip_parent_main_inheritance(
start_method, start_method,
reg_addr, reg_addr,
@ -225,23 +272,32 @@ def test_start_actor_can_skip_parent_main_inheritance(
), ),
) )
async def main() -> None: async def main():
async with tractor.open_nursery( async with tractor.open_nursery(
name='registrar', name='registrar',
start_method=start_method, start_method=start_method,
registry_addrs=[reg_addr], registry_addrs=[reg_addr],
) as an: ) as an:
replaying = await an.run_in_actor( replaying = await an.start_actor(
get_main_mod_name, 'replaying-parent-main',
name='replaying-parent-main', enable_modules=[__name__],
) )
isolated = await an.run_in_actor( isolated = await an.start_actor(
get_main_mod_name, 'isolated-parent-main',
name='isolated-parent-main', enable_modules=[__name__],
inherit_parent_main=False, inherit_parent_main=False,
) )
try:
assert await replaying.result() == '__mp_main__' assert await replaying.run_from_ns(
assert await isolated.result() == '__main__' __name__,
'get_main_mod_name',
) == '__mp_main__'
assert await isolated.run_from_ns(
__name__,
'get_main_mod_name',
) == '__main__'
finally:
await replaying.cancel_actor()
await isolated.cancel_actor()
trio.run(main) trio.run(main)

View File

@ -119,6 +119,7 @@ from ..discovery._discovery import get_registry
from ._portal import Portal from ._portal import Portal
from . import _state from . import _state
from ..spawn import _mp_fixup_main from ..spawn import _mp_fixup_main
from ..spawn._mp_fixup_main import ParentMainData
from . import _rpc from . import _rpc
if TYPE_CHECKING: if TYPE_CHECKING:
@ -218,7 +219,7 @@ class Actor:
return self._ipc_server return self._ipc_server
# Information about `__main__` from parent # Information about `__main__` from parent
_parent_main_data: dict[str, str] _parent_main_data: ParentMainData
_parent_chan_cs: CancelScope|None = None _parent_chan_cs: CancelScope|None = None
_spawn_spec: msgtypes.SpawnSpec|None = None _spawn_spec: msgtypes.SpawnSpec|None = None
@ -240,7 +241,7 @@ class Actor:
name: str, name: str,
uuid: str, uuid: str,
*, *,
enable_modules: list[str] = [], enable_modules: list[str] | None = None,
loglevel: str|None = None, loglevel: str|None = None,
registry_addrs: list[Address]|None = None, registry_addrs: list[Address]|None = None,
spawn_method: str|None = None, spawn_method: str|None = None,
@ -268,12 +269,13 @@ class Actor:
# retrieve and store parent `__main__` data which # retrieve and store parent `__main__` data which
# will be passed to children # will be passed to children
self._parent_main_data = _mp_fixup_main._mp_figure_out_main( self._parent_main_data: ParentMainData = _mp_fixup_main._mp_figure_out_main(
inherit_parent_main, inherit_parent_main=inherit_parent_main,
) )
# TODO? only add this when `is_debug_mode() == True` no? # TODO? only add this when `is_debug_mode() == True` no?
# always include debugging tools module # always include debugging tools module
enable_modules = list(enable_modules or [])
if _state.is_root_process(): if _state.is_root_process():
enable_modules.append('tractor.devx.debug._tty_lock') enable_modules.append('tractor.devx.debug._tty_lock')

View File

@ -200,7 +200,7 @@ class ActorNursery:
# a `._ria_nursery` since the dependent APIs have been # a `._ria_nursery` since the dependent APIs have been
# removed! # removed!
nursery: trio.Nursery|None = None, nursery: trio.Nursery|None = None,
proc_kwargs: dict[str, any] = {} proc_kwargs: dict[str, typing.Any] | None = None,
) -> Portal: ) -> Portal:
''' '''
@ -229,7 +229,8 @@ class ActorNursery:
_rtv['_debug_mode'] = debug_mode _rtv['_debug_mode'] = debug_mode
self._at_least_one_child_in_debug = True self._at_least_one_child_in_debug = True
enable_modules = enable_modules or [] enable_modules = list(enable_modules or [])
proc_kwargs = dict(proc_kwargs or {})
if rpc_module_paths: if rpc_module_paths:
warnings.warn( warnings.warn(
@ -296,7 +297,7 @@ class ActorNursery:
loglevel: str | None = None, # set log level per subactor loglevel: str | None = None, # set log level per subactor
infect_asyncio: bool = False, infect_asyncio: bool = False,
inherit_parent_main: bool = True, inherit_parent_main: bool = True,
proc_kwargs: dict[str, any] = {}, proc_kwargs: dict[str, typing.Any] | None = None,
**kwargs, # explicit args to ``fn`` **kwargs, # explicit args to ``fn``
@ -317,6 +318,7 @@ class ActorNursery:
# use the explicit function name if not provided # use the explicit function name if not provided
name = fn.__name__ name = fn.__name__
proc_kwargs = dict(proc_kwargs or {})
portal: Portal = await self.start_actor( portal: Portal = await self.start_actor(
name, name,
enable_modules=[mod_path] + ( enable_modules=[mod_path] + (

View File

@ -22,20 +22,34 @@ These helpers are needed for any spawning backend that doesn't already
handle this. For example when using ``trio_run_in_process`` it is needed handle this. For example when using ``trio_run_in_process`` it is needed
but obviously not when we're already using ``multiprocessing``. but obviously not when we're already using ``multiprocessing``.
These helpers mirror the stdlib spawn/forkserver bootstrap that rebuilds
the parent's `__main__` in a fresh child interpreter. In particular, we
capture enough info to later replay the parent's main module as
`__mp_main__` (or by path) in the child process.
See:
https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods
""" """
import os import os
import sys import sys
import platform import platform
import types import types
import runpy import runpy
from typing import NotRequired
from typing import TypedDict
ORIGINAL_DIR = os.path.abspath(os.getcwd()) ORIGINAL_DIR = os.path.abspath(os.getcwd())
class ParentMainData(TypedDict):
init_main_from_name: NotRequired[str]
init_main_from_path: NotRequired[str]
def _mp_figure_out_main( def _mp_figure_out_main(
inherit_parent_main: bool = True, inherit_parent_main: bool = True,
) -> dict[str, str]: ) -> ParentMainData:
"""Taken from ``multiprocessing.spawn.get_preparation_data()``. """Taken from ``multiprocessing.spawn.get_preparation_data()``.
Retrieve parent actor `__main__` module data. Retrieve parent actor `__main__` module data.
@ -43,7 +57,7 @@ def _mp_figure_out_main(
if not inherit_parent_main: if not inherit_parent_main:
return {} return {}
d = {} d: ParentMainData = {}
# Figure out whether to initialise main in the subprocess as a module # Figure out whether to initialise main in the subprocess as a module
# or through direct execution (or to leave it alone entirely) # or through direct execution (or to leave it alone entirely)
main_module = sys.modules['__main__'] main_module = sys.modules['__main__']