@@ -6,6 +6,7 @@ use collections::BTreeMap;
6
6
use gpui:: { App , Context , Entity , EventEmitter , Global , prelude:: * } ;
7
7
use std:: { str:: FromStr , sync:: Arc } ;
8
8
use thiserror:: Error ;
9
+ use util:: maybe;
9
10
10
11
pub fn init ( cx : & mut App ) {
11
12
let registry = cx. new ( |_cx| LanguageModelRegistry :: default ( ) ) ;
@@ -41,9 +42,7 @@ impl std::fmt::Debug for ConfigurationError {
41
42
#[ derive( Default ) ]
42
43
pub struct LanguageModelRegistry {
43
44
default_model : Option < ConfiguredModel > ,
44
- /// This model is automatically configured by a user's environment after
45
- /// authenticating all providers. It's only used when default_model is not available.
46
- environment_fallback_model : Option < ConfiguredModel > ,
45
+ default_fast_model : Option < ConfiguredModel > ,
47
46
inline_assistant_model : Option < ConfiguredModel > ,
48
47
commit_message_model : Option < ConfiguredModel > ,
49
48
thread_summary_model : Option < ConfiguredModel > ,
@@ -99,6 +98,9 @@ impl ConfiguredModel {
99
98
100
99
pub enum Event {
101
100
DefaultModelChanged ,
101
+ InlineAssistantModelChanged ,
102
+ CommitMessageModelChanged ,
103
+ ThreadSummaryModelChanged ,
102
104
ProviderStateChanged ( LanguageModelProviderId ) ,
103
105
AddedProvider ( LanguageModelProviderId ) ,
104
106
RemovedProvider ( LanguageModelProviderId ) ,
@@ -224,7 +226,7 @@ impl LanguageModelRegistry {
224
226
cx : & mut Context < Self > ,
225
227
) {
226
228
let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
227
- self . set_inline_assistant_model ( configured_model) ;
229
+ self . set_inline_assistant_model ( configured_model, cx ) ;
228
230
}
229
231
230
232
pub fn select_commit_message_model (
@@ -233,7 +235,7 @@ impl LanguageModelRegistry {
233
235
cx : & mut Context < Self > ,
234
236
) {
235
237
let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
236
- self . set_commit_message_model ( configured_model) ;
238
+ self . set_commit_message_model ( configured_model, cx ) ;
237
239
}
238
240
239
241
pub fn select_thread_summary_model (
@@ -242,7 +244,7 @@ impl LanguageModelRegistry {
242
244
cx : & mut Context < Self > ,
243
245
) {
244
246
let configured_model = model. and_then ( |model| self . select_model ( model, cx) ) ;
245
- self . set_thread_summary_model ( configured_model) ;
247
+ self . set_thread_summary_model ( configured_model, cx ) ;
246
248
}
247
249
248
250
/// Selects and sets the inline alternatives for language models based on
@@ -276,60 +278,68 @@ impl LanguageModelRegistry {
276
278
}
277
279
278
280
pub fn set_default_model ( & mut self , model : Option < ConfiguredModel > , cx : & mut Context < Self > ) {
279
- match ( self . default_model ( ) , model. as_ref ( ) ) {
281
+ match ( self . default_model . as_ref ( ) , model. as_ref ( ) ) {
280
282
( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
281
283
( None , None ) => { }
282
284
_ => cx. emit ( Event :: DefaultModelChanged ) ,
283
285
}
286
+ self . default_fast_model = maybe ! ( {
287
+ let provider = & model. as_ref( ) ?. provider;
288
+ let fast_model = provider. default_fast_model( cx) ?;
289
+ Some ( ConfiguredModel {
290
+ provider: provider. clone( ) ,
291
+ model: fast_model,
292
+ } )
293
+ } ) ;
284
294
self . default_model = model;
285
295
}
286
296
287
- pub fn set_environment_fallback_model (
297
+ pub fn set_inline_assistant_model (
288
298
& mut self ,
289
299
model : Option < ConfiguredModel > ,
290
300
cx : & mut Context < Self > ,
291
301
) {
292
- if self . default_model . is_none ( ) {
293
- match ( self . environment_fallback_model . as_ref ( ) , model. as_ref ( ) ) {
294
- ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
295
- ( None , None ) => { }
296
- _ => cx. emit ( Event :: DefaultModelChanged ) ,
297
- }
302
+ match ( self . inline_assistant_model . as_ref ( ) , model. as_ref ( ) ) {
303
+ ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
304
+ ( None , None ) => { }
305
+ _ => cx. emit ( Event :: InlineAssistantModelChanged ) ,
298
306
}
299
- self . environment_fallback_model = model;
300
- }
301
-
302
- pub fn set_inline_assistant_model ( & mut self , model : Option < ConfiguredModel > ) {
303
307
self . inline_assistant_model = model;
304
308
}
305
309
306
- pub fn set_commit_message_model ( & mut self , model : Option < ConfiguredModel > ) {
310
+ pub fn set_commit_message_model (
311
+ & mut self ,
312
+ model : Option < ConfiguredModel > ,
313
+ cx : & mut Context < Self > ,
314
+ ) {
315
+ match ( self . commit_message_model . as_ref ( ) , model. as_ref ( ) ) {
316
+ ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
317
+ ( None , None ) => { }
318
+ _ => cx. emit ( Event :: CommitMessageModelChanged ) ,
319
+ }
307
320
self . commit_message_model = model;
308
321
}
309
322
310
- pub fn set_thread_summary_model ( & mut self , model : Option < ConfiguredModel > ) {
323
+ pub fn set_thread_summary_model (
324
+ & mut self ,
325
+ model : Option < ConfiguredModel > ,
326
+ cx : & mut Context < Self > ,
327
+ ) {
328
+ match ( self . thread_summary_model . as_ref ( ) , model. as_ref ( ) ) {
329
+ ( Some ( old) , Some ( new) ) if old. is_same_as ( new) => { }
330
+ ( None , None ) => { }
331
+ _ => cx. emit ( Event :: ThreadSummaryModelChanged ) ,
332
+ }
311
333
self . thread_summary_model = model;
312
334
}
313
335
314
- #[ track_caller]
315
336
pub fn default_model ( & self ) -> Option < ConfiguredModel > {
316
337
#[ cfg( debug_assertions) ]
317
338
if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
318
339
return None ;
319
340
}
320
341
321
- self . default_model
322
- . clone ( )
323
- . or_else ( || self . environment_fallback_model . clone ( ) )
324
- }
325
-
326
- pub fn default_fast_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
327
- let provider = self . default_model ( ) ?. provider ;
328
- let fast_model = provider. default_fast_model ( cx) ?;
329
- Some ( ConfiguredModel {
330
- provider,
331
- model : fast_model,
332
- } )
342
+ self . default_model . clone ( )
333
343
}
334
344
335
345
pub fn inline_assistant_model ( & self ) -> Option < ConfiguredModel > {
@@ -343,27 +353,27 @@ impl LanguageModelRegistry {
343
353
. or_else ( || self . default_model . clone ( ) )
344
354
}
345
355
346
- pub fn commit_message_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
356
+ pub fn commit_message_model ( & self ) -> Option < ConfiguredModel > {
347
357
#[ cfg( debug_assertions) ]
348
358
if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
349
359
return None ;
350
360
}
351
361
352
362
self . commit_message_model
353
363
. clone ( )
354
- . or_else ( || self . default_fast_model ( cx ) )
364
+ . or_else ( || self . default_fast_model . clone ( ) )
355
365
. or_else ( || self . default_model . clone ( ) )
356
366
}
357
367
358
- pub fn thread_summary_model ( & self , cx : & App ) -> Option < ConfiguredModel > {
368
+ pub fn thread_summary_model ( & self ) -> Option < ConfiguredModel > {
359
369
#[ cfg( debug_assertions) ]
360
370
if std:: env:: var ( "ZED_SIMULATE_NO_LLM_PROVIDER" ) . is_ok ( ) {
361
371
return None ;
362
372
}
363
373
364
374
self . thread_summary_model
365
375
. clone ( )
366
- . or_else ( || self . default_fast_model ( cx ) )
376
+ . or_else ( || self . default_fast_model . clone ( ) )
367
377
. or_else ( || self . default_model . clone ( ) )
368
378
}
369
379
@@ -400,34 +410,4 @@ mod tests {
400
410
let providers = registry. read ( cx) . providers ( ) ;
401
411
assert ! ( providers. is_empty( ) ) ;
402
412
}
403
-
404
- #[ gpui:: test]
405
- async fn test_configure_environment_fallback_model ( cx : & mut gpui:: TestAppContext ) {
406
- let registry = cx. new ( |_| LanguageModelRegistry :: default ( ) ) ;
407
-
408
- let provider = FakeLanguageModelProvider :: default ( ) ;
409
- registry. update ( cx, |registry, cx| {
410
- registry. register_provider ( provider. clone ( ) , cx) ;
411
- } ) ;
412
-
413
- cx. update ( |cx| provider. authenticate ( cx) ) . await . unwrap ( ) ;
414
-
415
- registry. update ( cx, |registry, cx| {
416
- let provider = registry. provider ( & provider. id ( ) ) . unwrap ( ) ;
417
-
418
- registry. set_environment_fallback_model (
419
- Some ( ConfiguredModel {
420
- provider : provider. clone ( ) ,
421
- model : provider. default_model ( cx) . unwrap ( ) ,
422
- } ) ,
423
- cx,
424
- ) ;
425
-
426
- let default_model = registry. default_model ( ) . unwrap ( ) ;
427
- let fallback_model = registry. environment_fallback_model . clone ( ) . unwrap ( ) ;
428
-
429
- assert_eq ! ( default_model. model. id( ) , fallback_model. model. id( ) ) ;
430
- assert_eq ! ( default_model. provider. id( ) , fallback_model. provider. id( ) ) ;
431
- } ) ;
432
- }
433
413
}
0 commit comments