Skip to content

Commit a6c76b7

Browse files
committed
fix: added back retrieval of registered providers
Signed-off-by: Brian <[email protected]>
1 parent cb557c8 commit a6c76b7

File tree

7 files changed

+48
-16
lines changed

7 files changed

+48
-16
lines changed

packages/backend/src/managers/inference/inferenceManager.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,15 @@ export class InferenceManager extends Publisher<InferenceServer[]> implements Di
9898
return Array.from(this.#servers.values());
9999
}
100100

101+
/**
102+
* Get the Unique registered Inference provider types
103+
*/
104+
105+
public getRegisteredProviders(): InferenceType[] {
106+
const types: InferenceType[] = this.inferenceProviderRegistry.getAll().map(provider => provider.type);
107+
return [...new Set(types)];
108+
}
109+
101110
/**
102111
* return an inference server
103112
* @param containerId the containerId of the inference server

packages/backend/src/studio-api-impl.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import type { TaskRegistry } from './registries/TaskRegistry';
3030
import type { LocalRepository } from '@shared/models/ILocalRepository';
3131
import type { LocalRepositoryRegistry } from './registries/LocalRepositoryRegistry';
3232
import path from 'node:path';
33-
import type { InferenceServer } from '@shared/models/IInference';
33+
import type { InferenceServer, InferenceType } from '@shared/models/IInference';
3434
import type { CreationInferenceServerOptions } from '@shared/models/InferenceServerConfig';
3535
import type { InferenceManager } from './managers/inference/inferenceManager';
3636
import type { Conversation } from '@shared/models/IPlaygroundMessage';
@@ -144,6 +144,10 @@ export class StudioApiImpl implements StudioAPI {
144144
return this.inferenceManager.getServers();
145145
}
146146

147+
async getRegisteredProviders(): Promise<InferenceType[]> {
148+
return this.inferenceManager.getRegisteredProviders();
149+
}
150+
147151
async requestDeleteInferenceServer(...containerIds: string[]): Promise<void> {
148152
// Do not wait on the promise as the api would probably timeout before the user answer.
149153
if (containerIds.length === 0) throw new Error('At least one container id should be provided.');

packages/frontend/src/lib/select/InferenceRuntimeSelect.spec.ts

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ import { render, fireEvent, within } from '@testing-library/svelte';
2222
import InferenceRuntimeSelect from '/@/lib/select/InferenceRuntimeSelect.svelte';
2323
import { InferenceType } from '@shared/models/IInference';
2424

25-
const getFilteredOptions = (exclude: InferenceType[] = []): InferenceType[] =>
26-
Object.values(InferenceType).filter(type => !exclude.includes(type));
25+
const providers: InferenceType[] = [InferenceType.LLAMA_CPP, InferenceType.OPENVINO, InferenceType.WHISPER_CPP];
2726

2827
beforeEach(() => {
2928
// mock scrollIntoView
@@ -33,14 +32,15 @@ beforeEach(() => {
3332
test('Lists all runtime options', async () => {
3433
const { container } = render(InferenceRuntimeSelect, {
3534
value: undefined,
35+
providers,
3636
disabled: false,
3737
});
3838

3939
const input = within(container).getByLabelText('Select Inference Runtime');
4040
await fireEvent.pointerUp(input);
4141

4242
const items = container.querySelectorAll('div[class~="list-item"]');
43-
const expectedOptions = getFilteredOptions();
43+
const expectedOptions = providers;
4444

4545
expect(items.length).toBe(expectedOptions.length);
4646

@@ -52,14 +52,15 @@ test('Lists all runtime options', async () => {
5252
test('Selected value should be visible', async () => {
5353
const { container } = render(InferenceRuntimeSelect, {
5454
value: undefined,
55+
providers,
5556
disabled: false,
5657
});
5758

5859
const input = within(container).getByLabelText('Select Inference Runtime');
5960
await fireEvent.pointerUp(input);
6061

6162
const items = container.querySelectorAll('div[class~="list-item"]');
62-
const expectedOptions = getFilteredOptions();
63+
const expectedOptions = providers;
6364

6465
await fireEvent.click(items[0]);
6566

@@ -75,6 +76,7 @@ test('Exclude specific runtime from list', async () => {
7576

7677
const { container } = render(InferenceRuntimeSelect, {
7778
value: undefined,
79+
providers,
7880
disabled: false,
7981
exclude: excluded,
8082
});
@@ -89,7 +91,7 @@ test('Exclude specific runtime from list', async () => {
8991
expect(itemTexts).not.toContain(excludedType);
9092
});
9193

92-
const expected = getFilteredOptions(excluded);
94+
const expected = providers.filter(type => !excluded.includes(type));
9395

9496
expected.forEach(included => {
9597
expect(itemTexts).toContain(included);

packages/frontend/src/lib/select/InferenceRuntimeSelect.svelte

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
<script lang="ts">
22
import Select from '/@/lib/select/Select.svelte';
3-
import { InferenceType } from '@shared/models/IInference';
3+
import type { InferenceType } from '@shared/models/IInference';
44
55
interface Props {
66
disabled?: boolean;
77
value: InferenceType | undefined;
8+
providers: InferenceType[];
89
exclude?: InferenceType[];
910
}
11+
let { value = $bindable(), disabled, providers, exclude = [] }: Props = $props();
1012
11-
let { value = $bindable(), disabled, exclude = [] }: Props = $props();
12-
13-
const options = Object.values(InferenceType).filter(type => !exclude.includes(type));
13+
// Filter options based on optional exclude list
14+
const options = $derived(() =>
15+
providers.filter(type => !exclude.includes(type)).map(type => ({ value: type, label: type })),
16+
);
1417
1518
function handleOnChange(nValue: { value: string } | undefined): void {
1619
if (nValue) {
@@ -28,7 +31,4 @@ function handleOnChange(nValue: { value: string } | undefined): void {
2831
value={value ? { label: value, value: value } : undefined}
2932
onchange={handleOnChange}
3033
placeholder="Select Inference Runtime to use"
31-
items={options.map(type => ({
32-
value: type,
33-
label: type,
34-
}))} />
34+
items={options()} />

packages/frontend/src/pages/PlaygroundCreate.spec.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ vi.mock('../utils/client', async () => {
7272
studioClient: {
7373
requestCreatePlayground: vi.fn(),
7474
getExtensionConfiguration: vi.fn().mockResolvedValue({}),
75+
getRegisteredProviders: vi.fn().mockResolvedValue([]),
7576
},
7677
rpcBrowser: {
7778
subscribe: (): unknown => {
@@ -100,6 +101,11 @@ beforeEach(() => {
100101

101102
const tasksList = writable<Task[]>([]);
102103
vi.mocked(tasksStore).tasks = tasksList;
104+
vi.mocked(studioClient.getRegisteredProviders).mockResolvedValue([
105+
InferenceType.LLAMA_CPP,
106+
InferenceType.WHISPER_CPP,
107+
InferenceType.OPENVINO,
108+
]);
103109
});
104110

105111
test('model should be selected by default when runtime is set', async () => {

packages/frontend/src/pages/PlaygroundCreate.svelte

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@ let runtime: InferenceType | undefined = undefined;
2323
// Exlude certain runtimes from selection
2424
export let exclude: InferenceType[] = [InferenceType.NONE, InferenceType.WHISPER_CPP];
2525
26-
onMount(() => {
26+
// Get registered list of providers
27+
let providers: InferenceType[] = [];
28+
29+
onMount(async () => {
30+
providers = await studioClient.getRegisteredProviders();
31+
2732
const inferenceRuntime = $configuration?.inferenceRuntime;
2833
if (
2934
Object.values(InferenceType).includes(inferenceRuntime as InferenceType) &&
@@ -170,7 +175,7 @@ export function goToUpPage(): void {
170175
<label for="inference-runtime" class="pt-4 block mb-2 font-bold text-[var(--pd-content-card-header-text)]">
171176
Inference Runtime
172177
</label>
173-
<InferenceRuntimeSelect bind:value={runtime} exclude={exclude} />
178+
<InferenceRuntimeSelect bind:value={runtime} providers={providers} exclude={exclude} />
174179

175180
<!-- model input -->
176181
<label for="model" class="pt-4 block mb-2 font-bold text-[var(--pd-content-card-header-text)]">Model</label>

packages/shared/src/StudioAPI.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
***********************************************************************/
1818

1919
import type { ModelInfo } from './models/IModelInfo';
20+
import type { InferenceType } from '@shared/models/IInference';
2021
import type { ApplicationCatalog } from './models/IApplicationCatalog';
2122
import type { OpenDialogOptions, Uri } from '@podman-desktop/api';
2223
import type { ApplicationState } from './models/IApplicationState';
@@ -121,6 +122,11 @@ export interface StudioAPI {
121122
*/
122123
getInferenceServers(): Promise<InferenceServer[]>;
123124

125+
/**
126+
* Get inference providers
127+
*/
128+
getRegisteredProviders(): Promise<InferenceType[]>;
129+
124130
/**
125131
* Request to start an inference server
126132
* @param options The options to use

0 commit comments

Comments
 (0)