nvflare/app_opt/tracking/mlflow/mlflow_receiver.py (29 changes: 23 additions & 6 deletions)
@@ -23,8 +23,9 @@

 from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsData, AnalyticsDataType, LogWriterName, TrackConst
 from nvflare.apis.dxo import from_shareable
-from nvflare.apis.fl_constant import ProcessType
+from nvflare.apis.fl_constant import ProcessType, ReservedKey
 from nvflare.apis.fl_context import FLContext
+from nvflare.apis.job_def import JobMetaKey
 from nvflare.apis.shareable import Shareable
 from nvflare.app_common.widgets.streaming import AnalyticsReceiver
@@ -41,6 +42,14 @@ def get_current_time_millis():
     return int(round(time.time() * 1000))


+def _get_job_name_from_fl_ctx(fl_ctx: FLContext, default=None):
+    # TODO: it might be good to have a function in fl_context to get the job name
+    job_meta = fl_ctx.get_prop(ReservedKey.JOB_META)
+    if job_meta and isinstance(job_meta, dict):
+        return job_meta.get(JobMetaKey.JOB_NAME, default)
+    return default
+
+
 class MLflowReceiver(AnalyticsReceiver):
     def __init__(
         self,
@@ -80,7 +89,6 @@ def __init__(

         self.kw_args = kw_args if kw_args else {}
         self.tracking_uri = tracking_uri
-        self.mlflow = mlflow
         self.mlflow_clients: Dict[str, MlflowClient] = {}
         self.experiment_id = None
         self.run_ids = {}
@@ -89,8 +97,14 @@
         self.time_since_flush = 0
         self.buff_flush_time = buffer_flush_time

+    def _get_tracking_uri(self, fl_ctx: FLContext):
         if self.tracking_uri:
-            mlflow.set_tracking_uri(uri=self.tracking_uri)
+            return self.tracking_uri
+
+        workspace = fl_ctx.get_workspace()
+        job_id = fl_ctx.get_job_id()
+        mlflow_root = os.path.abspath(os.path.join(workspace.get_result_root(job_id), "mlflow"))
+        return f"file://{mlflow_root}"

     def initialize(self, fl_ctx: FLContext):
         """Initializes MlflowClient for each site.
@@ -112,6 +126,8 @@ def initialize(self, fl_ctx: FLContext):
         # Initialize context and timing
         self.time_start = 0

+        mlflow.set_tracking_uri(uri=self._get_tracking_uri(fl_ctx))
+
         # Validate and prepare experiment configuration
         art_full_path = self._get_artifact_location(self.artifact_location, fl_ctx)
         experiment_name = self.kw_args.get(TrackConst.EXPERIMENT_NAME, "FLARE FL Experiment")
@@ -164,8 +180,9 @@ def _mlflow_setup(self, art_full_path, experiment_name, experiment_tags, site_na
             )

             job_id_tag = self._get_job_id_tag(fl_ctx)
+            job_name = _get_job_name_from_fl_ctx(fl_ctx)

-            run_name = self._get_run_name(self.kw_args, site_name, job_id_tag)
+            run_name = self._get_run_name(self.kw_args, site_name, job_id_tag, job_name)
             tags = self._get_run_tags(self.kw_args, job_id_tag, run_name)
             run = mlflow_client.create_run(experiment_id=self.experiment_id, run_name=run_name, tags=tags)
             self.run_ids[site_name] = run.info.run_id
@@ -179,9 +196,9 @@ def _init_buffer(self, site_names: List[str]):
                 AnalyticsDataType.TAGS: [],
             }

-    def _get_run_name(self, kwargs: dict, site_name: str, job_id_tag: str):
+    def _get_run_name(self, kwargs: dict, site_name: str, job_id_tag: str, job_name: str):
         run_name = kwargs.get(TrackConst.RUN_NAME, DEFAULT_RUN_NAME)
-        return f"{site_name}-{job_id_tag[:6]}-{run_name}"
+        return f"{site_name}-{job_id_tag[:6]}-{job_name}-{run_name}"

     def _get_run_tags(self, kwargs, job_id_tag: str, run_name: str):
         run_tags = self._get_tags(TrackConst.RUN_TAGS, kwargs=kwargs)
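
Reviewer note: a minimal sketch of the two behavior changes above, for anyone who wants to see the net effect without running a job. The workspace path, job id, job name, and site name are made-up examples, and the default run name value is only an assumption (the module's DEFAULT_RUN_NAME is used in practice); the fallback and formatting logic mirror the diff.

# Sketch only: all concrete values below are illustrative, not taken from a real NVFlare workspace.
import os


def default_tracking_uri(result_root: str) -> str:
    # With no tracking_uri configured, the receiver now falls back to a file
    # store under the job's result root (see _get_tracking_uri above).
    return "file://" + os.path.abspath(os.path.join(result_root, "mlflow"))


def run_name(site_name: str, job_id_tag: str, job_name: str, base_run_name: str) -> str:
    # New run-name format: <site>-<first 6 chars of job id tag>-<job name>-<run name>
    return f"{site_name}-{job_id_tag[:6]}-{job_name}-{base_run_name}"


print(default_tracking_uri("/tmp/nvflare/server/simulate_job"))    # file:///tmp/nvflare/server/simulate_job/mlflow
print(run_name("site-1", "8f3c2b7e", "hello-pt", "FLARE FL Run"))  # site-1-8f3c2b-hello-pt-FLARE FL Run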
nvflare/app_opt/tracking/tb/tb_receiver.py (7 changes: 3 additions & 4 deletions)
@@ -55,14 +55,13 @@ def __init__(self, tb_folder="tb_events", events: Optional[List[str]] = None):
         .. code-block:: text
            :caption: Folder structure

-            Inside run_XX folder:
-              - workspace
-                - run_01 (already created):
+              - site_workspace
+                - job_id1 (already created):
                   - output_dir (default: tb_events):
                     - peer_name_1:
                     - peer_name_2:

-                - run_02 (already created):
+                - job_id2 (already created):
                   - output_dir (default: tb_events):
                     - peer_name_1:
                     - peer_name_2:
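
Reviewer note: under the corrected folder structure above, events streamed from each peer land under the receiving site's job workspace rather than a run_XX folder. A quick sketch of the resulting path; the workspace root, job id, and peer name are made-up, and only tb_events is the documented default output_dir:

# Sketch only: illustrative values, not from a real deployment.
import os

site_workspace = "/opt/nvflare/server"                 # hypothetical receiving site's workspace root
job_id = "f7c9d2e4-1234-5678-9abc-def012345678"        # hypothetical job id
tb_folder = "tb_events"                                # default output_dir from __init__
peer_name = "site-1"                                   # hypothetical peer streaming the events

print(os.path.join(site_workspace, job_id, tb_folder, peer_name))
# /opt/nvflare/server/f7c9d2e4-1234-5678-9abc-def012345678/tb_events/site-1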