Skip to content

Commit b8583e9

Browse files
authored
feat: added ability to select runtime for playground (#3171)
* feat: added ability to select runtime for playground Signed-off-by: Brian <[email protected]> * fix: updated runtime selection filter Signed-off-by: Brian <[email protected]> * fix: adjusted tests with new filter change Signed-off-by: Brian <[email protected]> * fix: removed none and enforced runtime exclusion Signed-off-by: Brian <[email protected]> * chore: edited e2e test to comply with new runtime selection Signed-off-by: Brian <[email protected]> * revert: model preset dependent on runtime it should always be preset Signed-off-by: Brian <[email protected]> * feat: get the inference provider from a registered store Signed-off-by: Brian <[email protected]> * chore: reworked getting registered providers Signed-off-by: Brian <[email protected]> * fix: working on refactoring change to use recommended runtime from config store Signed-off-by: Brian <[email protected]> * fix: use config store to preselect runtime fixed tests Signed-off-by: Brian <[email protected]> * fix: added back openvino selection this will show on mac as well Signed-off-by: Brian <[email protected]> * fix: added back retrieval of registered providers Signed-off-by: Brian <[email protected]> --------- Signed-off-by: Brian <[email protected]>
1 parent c1aadd8 commit b8583e9

File tree

7 files changed

+246
-12
lines changed

7 files changed

+246
-12
lines changed

packages/backend/src/managers/inference/inferenceManager.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ export class InferenceManager extends Publisher<InferenceServer[]> implements Di
9898
return Array.from(this.#servers.values());
9999
}
100100

101+
/**
102+
* Get the Unique registered Inference provider types
103+
*/
104+
105+
public getRegisteredProviders(): InferenceType[] {
106+
const types: InferenceType[] = this.inferenceProviderRegistry.getAll().map(provider => provider.type);
107+
return [...new Set(types)];
108+
}
109+
101110
/**
102111
* return an inference server
103112
* @param containerId the containerId of the inference server

packages/backend/src/studio-api-impl.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import type { TaskRegistry } from './registries/TaskRegistry';
3030
import type { LocalRepository } from '@shared/models/ILocalRepository';
3131
import type { LocalRepositoryRegistry } from './registries/LocalRepositoryRegistry';
3232
import path from 'node:path';
33-
import type { InferenceServer } from '@shared/models/IInference';
33+
import type { InferenceServer, InferenceType } from '@shared/models/IInference';
3434
import type { CreationInferenceServerOptions } from '@shared/models/InferenceServerConfig';
3535
import type { InferenceManager } from './managers/inference/inferenceManager';
3636
import type { Conversation } from '@shared/models/IPlaygroundMessage';
@@ -144,6 +144,10 @@ export class StudioApiImpl implements StudioAPI {
144144
return this.inferenceManager.getServers();
145145
}
146146

147+
async getRegisteredProviders(): Promise<InferenceType[]> {
148+
return this.inferenceManager.getRegisteredProviders();
149+
}
150+
147151
async requestDeleteInferenceServer(...containerIds: string[]): Promise<void> {
148152
// Do not wait on the promise as the api would probably timeout before the user answer.
149153
if (containerIds.length === 0) throw new Error('At least one container id should be provided.');
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/**********************************************************************
2+
* Copyright (C) 2025 Red Hat, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
* SPDX-License-Identifier: Apache-2.0
17+
***********************************************************************/
18+
19+
import '@testing-library/jest-dom/vitest';
20+
import { beforeEach, vi, test, expect } from 'vitest';
21+
import { render, fireEvent, within } from '@testing-library/svelte';
22+
import InferenceRuntimeSelect from '/@/lib/select/InferenceRuntimeSelect.svelte';
23+
import { InferenceType } from '@shared/models/IInference';
24+
25+
const providers: InferenceType[] = [InferenceType.LLAMA_CPP, InferenceType.OPENVINO, InferenceType.WHISPER_CPP];
26+
27+
beforeEach(() => {
28+
// mock scrollIntoView
29+
window.HTMLElement.prototype.scrollIntoView = vi.fn();
30+
});
31+
32+
test('Lists all runtime options', async () => {
33+
const { container } = render(InferenceRuntimeSelect, {
34+
value: undefined,
35+
providers,
36+
disabled: false,
37+
});
38+
39+
const input = within(container).getByLabelText('Select Inference Runtime');
40+
await fireEvent.pointerUp(input);
41+
42+
const items = container.querySelectorAll('div[class~="list-item"]');
43+
const expectedOptions = providers;
44+
45+
expect(items.length).toBe(expectedOptions.length);
46+
47+
expectedOptions.forEach((option, i) => {
48+
expect(items[i]).toHaveTextContent(option);
49+
});
50+
});
51+
52+
test('Selected value should be visible', async () => {
53+
const { container } = render(InferenceRuntimeSelect, {
54+
value: undefined,
55+
providers,
56+
disabled: false,
57+
});
58+
59+
const input = within(container).getByLabelText('Select Inference Runtime');
60+
await fireEvent.pointerUp(input);
61+
62+
const items = container.querySelectorAll('div[class~="list-item"]');
63+
const expectedOptions = providers;
64+
65+
await fireEvent.click(items[0]);
66+
67+
const valueContainer = container.querySelector('.value-container');
68+
if (!(valueContainer instanceof HTMLElement)) throw new Error('Missing value container');
69+
70+
const selectedLabel = within(valueContainer).getByText(expectedOptions[0]);
71+
expect(selectedLabel).toBeDefined();
72+
});
73+
74+
test('Exclude specific runtime from list', async () => {
75+
const excluded = [InferenceType.WHISPER_CPP, InferenceType.OPENVINO];
76+
77+
const { container } = render(InferenceRuntimeSelect, {
78+
value: undefined,
79+
providers,
80+
disabled: false,
81+
exclude: excluded,
82+
});
83+
84+
const input = within(container).getByLabelText('Select Inference Runtime');
85+
await fireEvent.pointerUp(input);
86+
87+
const items = container.querySelectorAll('div[class~="list-item"]');
88+
const itemTexts = Array.from(items).map(item => item.textContent?.trim());
89+
90+
excluded.forEach(excludedType => {
91+
expect(itemTexts).not.toContain(excludedType);
92+
});
93+
94+
const expected = providers.filter(type => !excluded.includes(type));
95+
96+
expected.forEach(included => {
97+
expect(itemTexts).toContain(included);
98+
});
99+
});
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
<script lang="ts">
2+
import Select from '/@/lib/select/Select.svelte';
3+
import type { InferenceType } from '@shared/models/IInference';
4+
5+
interface Props {
6+
disabled?: boolean;
7+
value: InferenceType | undefined;
8+
providers: InferenceType[];
9+
exclude?: InferenceType[];
10+
}
11+
let { value = $bindable(), disabled, providers, exclude = [] }: Props = $props();
12+
13+
// Filter options based on optional exclude list
14+
const options = $derived(() =>
15+
providers.filter(type => !exclude.includes(type)).map(type => ({ value: type, label: type })),
16+
);
17+
18+
function handleOnChange(nValue: { value: string } | undefined): void {
19+
if (nValue) {
20+
value = nValue.value as InferenceType;
21+
} else {
22+
value = undefined;
23+
}
24+
}
25+
</script>
26+
27+
<Select
28+
label="Select Inference Runtime"
29+
name="select-inference-runtime"
30+
disabled={disabled}
31+
value={value ? { label: value, value: value } : undefined}
32+
onchange={handleOnChange}
33+
placeholder="Select Inference Runtime to use"
34+
items={options()} />

packages/frontend/src/pages/PlaygroundCreate.spec.ts

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,24 @@ const dummyWhisperCppModel: ModelInfo = {
5555
backend: InferenceType.WHISPER_CPP,
5656
};
5757

58+
const dummyOpenVinoModel: ModelInfo = {
59+
id: 'openvino-model-id',
60+
name: 'Dummy Openvino model',
61+
file: {
62+
file: 'file',
63+
path: path.resolve(os.tmpdir(), 'path'),
64+
},
65+
properties: {},
66+
description: '',
67+
backend: InferenceType.OPENVINO,
68+
};
69+
5870
vi.mock('../utils/client', async () => {
5971
return {
6072
studioClient: {
6173
requestCreatePlayground: vi.fn(),
6274
getExtensionConfiguration: vi.fn().mockResolvedValue({}),
75+
getRegisteredProviders: vi.fn().mockResolvedValue([]),
6376
},
6477
rpcBrowser: {
6578
subscribe: (): unknown => {
@@ -88,28 +101,58 @@ beforeEach(() => {
88101

89102
const tasksList = writable<Task[]>([]);
90103
vi.mocked(tasksStore).tasks = tasksList;
104+
vi.mocked(studioClient.getRegisteredProviders).mockResolvedValue([
105+
InferenceType.LLAMA_CPP,
106+
InferenceType.WHISPER_CPP,
107+
InferenceType.OPENVINO,
108+
]);
91109
});
92110

93-
test('model should be selected by default', () => {
111+
test('model should be selected by default when runtime is set', async () => {
94112
const modelsInfoList = writable<ModelInfo[]>([dummyLlamaCppModel]);
95113
vi.mocked(modelsInfoStore).modelsInfo = modelsInfoList;
96114

97115
vi.mocked(studioClient.requestCreatePlayground).mockRejectedValue('error creating playground');
98116

99-
const { container } = render(PlaygroundCreate);
117+
const { container } = render(PlaygroundCreate, { props: { exclude: [InferenceType.NONE] } });
118+
119+
// Select our runtime
120+
const dropdown = within(container).getByLabelText('Select Inference Runtime');
121+
await userEvent.click(dropdown);
122+
123+
const llamacppOption = within(container).getByText(InferenceType.LLAMA_CPP);
124+
await userEvent.click(llamacppOption);
100125

101126
const model = within(container).getByText(dummyLlamaCppModel.name);
102127
expect(model).toBeInTheDocument();
103128
});
104129

105-
test('models with incompatible backend should not be listed', async () => {
106-
const modelsInfoList = writable<ModelInfo[]>([dummyWhisperCppModel]);
130+
test('selecting a runtime filters the displayed models', async () => {
131+
const modelsInfoList = writable<ModelInfo[]>([dummyLlamaCppModel, dummyWhisperCppModel, dummyOpenVinoModel]);
132+
vi.mocked(modelsInfoStore).modelsInfo = modelsInfoList;
133+
134+
const { container } = render(PlaygroundCreate, { props: { exclude: [InferenceType.NONE] } });
135+
136+
// Select our runtime
137+
const dropdown = within(container).getByLabelText('Select Inference Runtime');
138+
await userEvent.click(dropdown);
139+
140+
const openvinoOption = within(container).getByText(InferenceType.OPENVINO);
141+
await userEvent.click(openvinoOption);
142+
143+
expect(within(container).queryByText(dummyOpenVinoModel.name)).toBeInTheDocument();
144+
expect(within(container).queryByText(dummyLlamaCppModel.name)).toBeNull();
145+
expect(within(container).queryByText(dummyWhisperCppModel.name)).toBeNull();
146+
});
147+
148+
test('should show warning when no local models are available', () => {
149+
const modelsInfoList = writable<ModelInfo[]>([]);
107150
vi.mocked(modelsInfoStore).modelsInfo = modelsInfoList;
108151

109152
const { container } = render(PlaygroundCreate);
110153

111-
const model = within(container).queryByText(dummyWhisperCppModel.name);
112-
expect(model).toBeNull();
154+
const warning = within(container).getByText(/You don't have any models downloaded/);
155+
expect(warning).toBeInTheDocument();
113156
});
114157

115158
test('should display error message if createPlayground fails', async () => {
@@ -123,6 +166,13 @@ test('should display error message if createPlayground fails', async () => {
123166
const errorMessage = within(container).queryByLabelText('Error Message Content');
124167
expect(errorMessage).not.toBeInTheDocument();
125168

169+
// Select the runtime first
170+
const runtimeDropdown = within(container).getByLabelText('Select Inference Runtime');
171+
await userEvent.click(runtimeDropdown);
172+
173+
const runtimeOption = within(container).getByText(InferenceType.LLAMA_CPP);
174+
await userEvent.click(runtimeOption);
175+
126176
const createButton = within(container).getByTitle('Create playground');
127177
await userEvent.click(createButton);
128178

packages/frontend/src/pages/PlaygroundCreate.svelte

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,34 @@ import type { Unsubscriber } from 'svelte/store';
1414
import { Button, ErrorMessage, FormPage, Input } from '@podman-desktop/ui-svelte';
1515
import ModelSelect from '/@/lib/select/ModelSelect.svelte';
1616
import { InferenceType } from '@shared/models/IInference';
17+
import InferenceRuntimeSelect from '/@/lib/select/InferenceRuntimeSelect.svelte';
18+
import { configuration } from '../stores/extensionConfiguration';
19+
20+
// Get recommended runtime
21+
let runtime: InferenceType | undefined = undefined;
22+
23+
// Exclude certain runtimes from selection
24+
export let exclude: InferenceType[] = [InferenceType.NONE, InferenceType.WHISPER_CPP];
25+
26+
// Get registered list of providers
27+
let providers: InferenceType[] = [];
28+
29+
onMount(async () => {
30+
providers = await studioClient.getRegisteredProviders();
31+
32+
const inferenceRuntime = $configuration?.inferenceRuntime;
33+
if (
34+
Object.values(InferenceType).includes(inferenceRuntime as InferenceType) &&
35+
!exclude.includes(inferenceRuntime as InferenceType)
36+
) {
37+
runtime = inferenceRuntime as InferenceType;
38+
}
39+
});
1740
1841
let localModels: ModelInfo[];
19-
$: localModels = $modelsInfo.filter(model => model.file && model.backend !== InferenceType.WHISPER_CPP);
42+
$: localModels = $modelsInfo.filter(
43+
model => model.file && (!runtime || model.backend === runtime) && !exclude.includes(model.backend as InferenceType),
44+
);
2045
$: availModels = $modelsInfo.filter(model => !model.file);
2146
let model: ModelInfo | undefined = undefined;
2247
let submitted: boolean = false;
@@ -30,10 +55,11 @@ let trackingId: string | undefined = undefined;
3055
// The trackedTasks are the tasks linked to the trackingId
3156
let trackedTasks: Task[] = [];
3257
33-
$: {
34-
if (!model && localModels.length > 0) {
35-
model = localModels[0];
36-
}
58+
// Preset model selection depending on runtime
59+
$: if (localModels.length > 0) {
60+
model = localModels[0];
61+
} else {
62+
model = undefined;
3763
}
3864
3965
function openModelsPage(): void {
@@ -145,6 +171,12 @@ export function goToUpPage(): void {
145171
placeholder="Leave blank to generate a name"
146172
aria-label="playgroundName" />
147173

174+
<!-- inference runtime -->
175+
<label for="inference-runtime" class="pt-4 block mb-2 font-bold text-[var(--pd-content-card-header-text)]">
176+
Inference Runtime
177+
</label>
178+
<InferenceRuntimeSelect bind:value={runtime} providers={providers} exclude={exclude} />
179+
148180
<!-- model input -->
149181
<label for="model" class="pt-4 block mb-2 font-bold text-[var(--pd-content-card-header-text)]">Model</label>
150182
<ModelSelect models={localModels} disabled={submitted} bind:value={model} />

packages/shared/src/StudioAPI.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
***********************************************************************/
1818

1919
import type { ModelInfo } from './models/IModelInfo';
20+
import type { InferenceType } from '@shared/models/IInference';
2021
import type { ApplicationCatalog } from './models/IApplicationCatalog';
2122
import type { OpenDialogOptions, Uri } from '@podman-desktop/api';
2223
import type { ApplicationState } from './models/IApplicationState';
@@ -121,6 +122,11 @@ export interface StudioAPI {
121122
*/
122123
getInferenceServers(): Promise<InferenceServer[]>;
123124

125+
/**
126+
* Get inference providers
127+
*/
128+
getRegisteredProviders(): Promise<InferenceType[]>;
129+
124130
/**
125131
* Request to start an inference server
126132
* @param options The options to use

0 commit comments

Comments
 (0)