Skip to content

Commit 3a66593

Browse files
authored
Merge pull request #150 from VectorInstitute/bugfix/broken-throughput
Fix broken throughput computation
2 parents 8442168 + 5e1ce48 commit 3a66593

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

.github/workflows/code_checks.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,5 @@ jobs:
4949
uses: pypa/[email protected]
5050
with:
5151
virtual-environment: .venv/
52+
# Temporary: ignore pip advisory until fixed in pip>=25.3
53+
ignore-vulns: GHSA-4xh5-x5gv-qwph

vec_inf/cli/_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def metrics(slurm_job_id: str) -> None:
447447
metrics_formatter.format_metrics()
448448

449449
live.update(metrics_formatter.table)
450-
time.sleep(2)
450+
time.sleep(1)
451451
except click.ClickException as e:
452452
raise e
453453
except Exception as e:

vec_inf/client/api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class VecInfClient:
8181

8282
def __init__(self) -> None:
8383
"""Initialize the Vector Inference client."""
84-
pass
84+
self._metrics_collectors: dict[str, PerformanceMetricsCollector] = {}
8585

8686
def list_models(self) -> list[ModelInfo]:
8787
"""List all available models.
@@ -218,7 +218,13 @@ def get_metrics(self, slurm_job_id: str) -> MetricsResponse:
218218
- Performance metrics or error message
219219
- Timestamp of collection
220220
"""
221-
performance_metrics_collector = PerformanceMetricsCollector(slurm_job_id)
221+
# Use cached collector to preserve state between calls to compute throughput
222+
if slurm_job_id not in self._metrics_collectors:
223+
self._metrics_collectors[slurm_job_id] = PerformanceMetricsCollector(
224+
slurm_job_id
225+
)
226+
227+
performance_metrics_collector = self._metrics_collectors[slurm_job_id]
222228

223229
metrics: Union[dict[str, float], str]
224230
if not performance_metrics_collector.metrics_url.startswith("http"):

0 commit comments

Comments
 (0)