Skip to content

Commit 75f3e7c

Browse files
committed
Refactored so overriden logic lives in register_activity() so async works regardless of how we register it...
Signed-off-by: Patrick Assuied <[email protected]>
1 parent d976959 commit 75f3e7c

File tree

1 file changed

+48
-28
lines changed

1 file changed

+48
-28
lines changed

ext/dapr-ext-workflow/dapr/ext/workflow/aio/workflow_runtime.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,44 @@ def __init__(
6161
self._main_event_loop = main_event_loop
6262
self._logger = Logger('WorkflowRuntime', logger_options)
6363

64+
def register_activity(self, fn: Activity, *, name: Optional[str] = None):
65+
"""Registers an async workflow activity and ensures proper execution and metadata.
66+
This mirrors the decorator behavior so direct registration behaves identically.
67+
"""
68+
# Validate/prepare alternate name on the original function first (not the wrapper)
69+
if hasattr(fn, '_activity_registered'):
70+
alt_name = fn.__dict__['_dapr_alternate_name']
71+
raise ValueError(f'Activity {fn.__name__} already registered as {alt_name}')
72+
if hasattr(fn, '_dapr_alternate_name'):
73+
alt_name = fn._dapr_alternate_name
74+
if name is not None:
75+
raise ValueError(f'Activity {fn.__name__} already has an alternate name {alt_name}')
76+
else:
77+
alt_name = name if name else fn.__name__
78+
fn.__dict__['_dapr_alternate_name'] = alt_name
79+
80+
# Build the target function that super().register_activity will wrap.
81+
if self._main_event_loop:
82+
83+
@wraps(fn)
84+
def target_fn(*args, **kwargs):
85+
result = fn(*args, **kwargs)
86+
if inspect.isawaitable(result):
87+
future = asyncio.run_coroutine_threadsafe(result, self._main_event_loop)
88+
return future.result()
89+
return result
90+
else:
91+
92+
@wraps(fn)
93+
def target_fn():
94+
return fn
95+
96+
# Delegate to base registration without passing name because the wrapper already
97+
# carries _dapr_alternate_name via @wraps(fn) copying fn.__dict__.
98+
super().register_activity(target_fn, name=None)
99+
fn.__dict__['_activity_registered'] = True
100+
fn.__dict__['_dapr_alternate_name'] = alt_name
101+
64102
def activity(self, __fn: Activity = None, *, name: Optional[str] = None):
65103
"""Decorator to register an async activity function.
66104
@@ -88,38 +126,20 @@ async def add(ctx, x: int, y: int) -> int:
88126
"""
89127

90128
def wrapper(fn: Activity):
91-
# If a main event loop is provided, wrap the function so that any awaitable
92-
# result is executed on that loop in a thread-safe way.
93-
if self._main_event_loop:
94-
95-
@wraps(fn)
96-
def sync_wrapper(*args, **kwargs):
97-
result = fn(*args, **kwargs)
98-
if inspect.isawaitable(result):
99-
future = asyncio.run_coroutine_threadsafe(result, self._main_event_loop)
100-
return future.result()
101-
return result
102-
103-
target_fn = sync_wrapper
104-
else:
105-
# No special handling needed; register the original function directly.
106-
@wraps(fn)
107-
def innerfn():
108-
return fn
109-
110-
target_fn = innerfn
129+
# Register first to ensure original fn has flags set
130+
self.register_activity(fn, name=name)
111131

112-
self.register_activity(target_fn, name=name)
132+
@wraps(fn)
133+
def innerfn():
134+
return fn
113135

136+
# Mirror naming metadata on the returned decorator wrapper
114137
if hasattr(fn, '_dapr_alternate_name'):
115-
target_fn.__dict__['_dapr_alternate_name'] = fn.__dict__['_dapr_alternate_name']
138+
innerfn.__dict__['_dapr_alternate_name'] = fn.__dict__['_dapr_alternate_name']
116139
else:
117-
target_fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__
118-
target_fn.__signature__ = inspect.signature(fn)
119-
# Copy attributes to fn so it doesn't get registered again when calling `register_activity()` again.
120-
fn.__dict__['_activity_registered'] = target_fn.__dict__['_activity_registered']
121-
fn.__dict__['_dapr_alternate_name'] = target_fn.__dict__['_dapr_alternate_name']
122-
return target_fn
140+
innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__
141+
innerfn.__signature__ = inspect.signature(fn)
142+
return innerfn
123143

124144
if __fn:
125145
# This case is true when the decorator is used without arguments

0 commit comments

Comments
 (0)