Skip to content

Commit 858ab9c

Browse files
authored
Revert "ai: Auto select user model when there's no default" (#36932)
Reverts #36722 Release Notes: - N/A
1 parent 2c64b05 commit 858ab9c

File tree

9 files changed

+122
-184
lines changed

9 files changed

+122
-184
lines changed

crates/agent/src/thread.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ impl Thread {
664664
}
665665

666666
pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
667-
if self.configured_model.is_none() || self.messages.is_empty() {
667+
if self.configured_model.is_none() {
668668
self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
669669
}
670670
self.configured_model.clone()
@@ -2097,7 +2097,7 @@ impl Thread {
20972097
}
20982098

20992099
pub fn summarize(&mut self, cx: &mut Context<Self>) {
2100-
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model(cx) else {
2100+
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
21012101
println!("No thread summary model");
21022102
return;
21032103
};
@@ -2416,7 +2416,7 @@ impl Thread {
24162416
}
24172417

24182418
let Some(ConfiguredModel { model, provider }) =
2419-
LanguageModelRegistry::read_global(cx).thread_summary_model(cx)
2419+
LanguageModelRegistry::read_global(cx).thread_summary_model()
24202420
else {
24212421
return;
24222422
};
@@ -5410,10 +5410,13 @@ fn main() {{
54105410
}),
54115411
cx,
54125412
);
5413-
registry.set_thread_summary_model(Some(ConfiguredModel {
5414-
provider,
5415-
model: model.clone(),
5416-
}));
5413+
registry.set_thread_summary_model(
5414+
Some(ConfiguredModel {
5415+
provider,
5416+
model: model.clone(),
5417+
}),
5418+
cx,
5419+
);
54175420
})
54185421
});
54195422

crates/agent2/src/agent.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ impl NativeAgent {
228228
) -> Entity<AcpThread> {
229229
let connection = Rc::new(NativeAgentConnection(cx.entity()));
230230
let registry = LanguageModelRegistry::read_global(cx);
231-
let summarization_model = registry.thread_summary_model(cx).map(|c| c.model);
231+
let summarization_model = registry.thread_summary_model().map(|c| c.model);
232232

233233
thread_handle.update(cx, |thread, cx| {
234234
thread.set_summarization_model(summarization_model, cx);
@@ -524,7 +524,7 @@ impl NativeAgent {
524524

525525
let registry = LanguageModelRegistry::read_global(cx);
526526
let default_model = registry.default_model().map(|m| m.model);
527-
let summarization_model = registry.thread_summary_model(cx).map(|m| m.model);
527+
let summarization_model = registry.thread_summary_model().map(|m| m.model);
528528

529529
for session in self.sessions.values_mut() {
530530
session.thread.update(cx, |thread, cx| {

crates/agent2/src/tests/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1822,11 +1822,11 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
18221822
let clock = Arc::new(clock::FakeSystemClock::new());
18231823
let client = Client::new(clock, http_client, cx);
18241824
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1825-
Project::init_settings(cx);
1826-
agent_settings::init(cx);
18271825
language_model::init(client.clone(), cx);
18281826
language_models::init(user_store, client.clone(), cx);
1827+
Project::init_settings(cx);
18291828
LanguageModelRegistry::test(cx);
1829+
agent_settings::init(cx);
18301830
});
18311831
cx.executor().forbid_parking();
18321832

crates/agent_ui/src/language_model_selector.rs

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use feature_flags::ZedProFeatureFlag;
66
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
77
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
88
use language_model::{
9-
ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
9+
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
10+
LanguageModelRegistry,
1011
};
1112
use ordered_float::OrderedFloat;
1213
use picker::{Picker, PickerDelegate};
@@ -76,6 +77,7 @@ pub struct LanguageModelPickerDelegate {
7677
all_models: Arc<GroupedModels>,
7778
filtered_entries: Vec<LanguageModelPickerEntry>,
7879
selected_index: usize,
80+
_authenticate_all_providers_task: Task<()>,
7981
_subscriptions: Vec<Subscription>,
8082
}
8183

@@ -96,6 +98,7 @@ impl LanguageModelPickerDelegate {
9698
selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
9799
filtered_entries: entries,
98100
get_active_model: Arc::new(get_active_model),
101+
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
99102
_subscriptions: vec![cx.subscribe_in(
100103
&LanguageModelRegistry::global(cx),
101104
window,
@@ -139,6 +142,56 @@ impl LanguageModelPickerDelegate {
139142
.unwrap_or(0)
140143
}
141144

145+
/// Authenticates all providers in the [`LanguageModelRegistry`].
146+
///
147+
/// We do this so that we can populate the language selector with all of the
148+
/// models from the configured providers.
149+
fn authenticate_all_providers(cx: &mut App) -> Task<()> {
150+
let authenticate_all_providers = LanguageModelRegistry::global(cx)
151+
.read(cx)
152+
.providers()
153+
.iter()
154+
.map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
155+
.collect::<Vec<_>>();
156+
157+
cx.spawn(async move |_cx| {
158+
for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
159+
if let Err(err) = authenticate_task.await {
160+
if matches!(err, AuthenticateError::CredentialsNotFound) {
161+
// Since we're authenticating these providers in the
162+
// background for the purposes of populating the
163+
// language selector, we don't care about providers
164+
// where the credentials are not found.
165+
} else {
166+
// Some providers have noisy failure states that we
167+
// don't want to spam the logs with every time the
168+
// language model selector is initialized.
169+
//
170+
// Ideally these should have more clear failure modes
171+
// that we know are safe to ignore here, like what we do
172+
// with `CredentialsNotFound` above.
173+
match provider_id.0.as_ref() {
174+
"lmstudio" | "ollama" => {
175+
// LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
176+
//
177+
// These fail noisily, so we don't log them.
178+
}
179+
"copilot_chat" => {
180+
// Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
181+
}
182+
_ => {
183+
log::error!(
184+
"Failed to authenticate provider: {}: {err}",
185+
provider_name.0
186+
);
187+
}
188+
}
189+
}
190+
}
191+
}
192+
})
193+
}
194+
142195
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
143196
(self.get_active_model)(cx)
144197
}

crates/git_ui/src/git_panel.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4466,7 +4466,7 @@ fn current_language_model(cx: &Context<'_, GitPanel>) -> Option<Arc<dyn Language
44664466
is_enabled
44674467
.then(|| {
44684468
let ConfiguredModel { provider, model } =
4469-
LanguageModelRegistry::read_global(cx).commit_message_model(cx)?;
4469+
LanguageModelRegistry::read_global(cx).commit_message_model()?;
44704470

44714471
provider.is_authenticated(cx).then(|| model)
44724472
})

crates/language_model/src/registry.rs

Lines changed: 47 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use collections::BTreeMap;
66
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
77
use std::{str::FromStr, sync::Arc};
88
use thiserror::Error;
9+
use util::maybe;
910

1011
pub fn init(cx: &mut App) {
1112
let registry = cx.new(|_cx| LanguageModelRegistry::default());
@@ -41,9 +42,7 @@ impl std::fmt::Debug for ConfigurationError {
4142
#[derive(Default)]
4243
pub struct LanguageModelRegistry {
4344
default_model: Option<ConfiguredModel>,
44-
/// This model is automatically configured by a user's environment after
45-
/// authenticating all providers. It's only used when default_model is not available.
46-
environment_fallback_model: Option<ConfiguredModel>,
45+
default_fast_model: Option<ConfiguredModel>,
4746
inline_assistant_model: Option<ConfiguredModel>,
4847
commit_message_model: Option<ConfiguredModel>,
4948
thread_summary_model: Option<ConfiguredModel>,
@@ -99,6 +98,9 @@ impl ConfiguredModel {
9998

10099
pub enum Event {
101100
DefaultModelChanged,
101+
InlineAssistantModelChanged,
102+
CommitMessageModelChanged,
103+
ThreadSummaryModelChanged,
102104
ProviderStateChanged(LanguageModelProviderId),
103105
AddedProvider(LanguageModelProviderId),
104106
RemovedProvider(LanguageModelProviderId),
@@ -224,7 +226,7 @@ impl LanguageModelRegistry {
224226
cx: &mut Context<Self>,
225227
) {
226228
let configured_model = model.and_then(|model| self.select_model(model, cx));
227-
self.set_inline_assistant_model(configured_model);
229+
self.set_inline_assistant_model(configured_model, cx);
228230
}
229231

230232
pub fn select_commit_message_model(
@@ -233,7 +235,7 @@ impl LanguageModelRegistry {
233235
cx: &mut Context<Self>,
234236
) {
235237
let configured_model = model.and_then(|model| self.select_model(model, cx));
236-
self.set_commit_message_model(configured_model);
238+
self.set_commit_message_model(configured_model, cx);
237239
}
238240

239241
pub fn select_thread_summary_model(
@@ -242,7 +244,7 @@ impl LanguageModelRegistry {
242244
cx: &mut Context<Self>,
243245
) {
244246
let configured_model = model.and_then(|model| self.select_model(model, cx));
245-
self.set_thread_summary_model(configured_model);
247+
self.set_thread_summary_model(configured_model, cx);
246248
}
247249

248250
/// Selects and sets the inline alternatives for language models based on
@@ -276,60 +278,68 @@ impl LanguageModelRegistry {
276278
}
277279

278280
pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
279-
match (self.default_model(), model.as_ref()) {
281+
match (self.default_model.as_ref(), model.as_ref()) {
280282
(Some(old), Some(new)) if old.is_same_as(new) => {}
281283
(None, None) => {}
282284
_ => cx.emit(Event::DefaultModelChanged),
283285
}
286+
self.default_fast_model = maybe!({
287+
let provider = &model.as_ref()?.provider;
288+
let fast_model = provider.default_fast_model(cx)?;
289+
Some(ConfiguredModel {
290+
provider: provider.clone(),
291+
model: fast_model,
292+
})
293+
});
284294
self.default_model = model;
285295
}
286296

287-
pub fn set_environment_fallback_model(
297+
pub fn set_inline_assistant_model(
288298
&mut self,
289299
model: Option<ConfiguredModel>,
290300
cx: &mut Context<Self>,
291301
) {
292-
if self.default_model.is_none() {
293-
match (self.environment_fallback_model.as_ref(), model.as_ref()) {
294-
(Some(old), Some(new)) if old.is_same_as(new) => {}
295-
(None, None) => {}
296-
_ => cx.emit(Event::DefaultModelChanged),
297-
}
302+
match (self.inline_assistant_model.as_ref(), model.as_ref()) {
303+
(Some(old), Some(new)) if old.is_same_as(new) => {}
304+
(None, None) => {}
305+
_ => cx.emit(Event::InlineAssistantModelChanged),
298306
}
299-
self.environment_fallback_model = model;
300-
}
301-
302-
pub fn set_inline_assistant_model(&mut self, model: Option<ConfiguredModel>) {
303307
self.inline_assistant_model = model;
304308
}
305309

306-
pub fn set_commit_message_model(&mut self, model: Option<ConfiguredModel>) {
310+
pub fn set_commit_message_model(
311+
&mut self,
312+
model: Option<ConfiguredModel>,
313+
cx: &mut Context<Self>,
314+
) {
315+
match (self.commit_message_model.as_ref(), model.as_ref()) {
316+
(Some(old), Some(new)) if old.is_same_as(new) => {}
317+
(None, None) => {}
318+
_ => cx.emit(Event::CommitMessageModelChanged),
319+
}
307320
self.commit_message_model = model;
308321
}
309322

310-
pub fn set_thread_summary_model(&mut self, model: Option<ConfiguredModel>) {
323+
pub fn set_thread_summary_model(
324+
&mut self,
325+
model: Option<ConfiguredModel>,
326+
cx: &mut Context<Self>,
327+
) {
328+
match (self.thread_summary_model.as_ref(), model.as_ref()) {
329+
(Some(old), Some(new)) if old.is_same_as(new) => {}
330+
(None, None) => {}
331+
_ => cx.emit(Event::ThreadSummaryModelChanged),
332+
}
311333
self.thread_summary_model = model;
312334
}
313335

314-
#[track_caller]
315336
pub fn default_model(&self) -> Option<ConfiguredModel> {
316337
#[cfg(debug_assertions)]
317338
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
318339
return None;
319340
}
320341

321-
self.default_model
322-
.clone()
323-
.or_else(|| self.environment_fallback_model.clone())
324-
}
325-
326-
pub fn default_fast_model(&self, cx: &App) -> Option<ConfiguredModel> {
327-
let provider = self.default_model()?.provider;
328-
let fast_model = provider.default_fast_model(cx)?;
329-
Some(ConfiguredModel {
330-
provider,
331-
model: fast_model,
332-
})
342+
self.default_model.clone()
333343
}
334344

335345
pub fn inline_assistant_model(&self) -> Option<ConfiguredModel> {
@@ -343,27 +353,27 @@ impl LanguageModelRegistry {
343353
.or_else(|| self.default_model.clone())
344354
}
345355

346-
pub fn commit_message_model(&self, cx: &App) -> Option<ConfiguredModel> {
356+
pub fn commit_message_model(&self) -> Option<ConfiguredModel> {
347357
#[cfg(debug_assertions)]
348358
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
349359
return None;
350360
}
351361

352362
self.commit_message_model
353363
.clone()
354-
.or_else(|| self.default_fast_model(cx))
364+
.or_else(|| self.default_fast_model.clone())
355365
.or_else(|| self.default_model.clone())
356366
}
357367

358-
pub fn thread_summary_model(&self, cx: &App) -> Option<ConfiguredModel> {
368+
pub fn thread_summary_model(&self) -> Option<ConfiguredModel> {
359369
#[cfg(debug_assertions)]
360370
if std::env::var("ZED_SIMULATE_NO_LLM_PROVIDER").is_ok() {
361371
return None;
362372
}
363373

364374
self.thread_summary_model
365375
.clone()
366-
.or_else(|| self.default_fast_model(cx))
376+
.or_else(|| self.default_fast_model.clone())
367377
.or_else(|| self.default_model.clone())
368378
}
369379

@@ -400,34 +410,4 @@ mod tests {
400410
let providers = registry.read(cx).providers();
401411
assert!(providers.is_empty());
402412
}
403-
404-
#[gpui::test]
405-
async fn test_configure_environment_fallback_model(cx: &mut gpui::TestAppContext) {
406-
let registry = cx.new(|_| LanguageModelRegistry::default());
407-
408-
let provider = FakeLanguageModelProvider::default();
409-
registry.update(cx, |registry, cx| {
410-
registry.register_provider(provider.clone(), cx);
411-
});
412-
413-
cx.update(|cx| provider.authenticate(cx)).await.unwrap();
414-
415-
registry.update(cx, |registry, cx| {
416-
let provider = registry.provider(&provider.id()).unwrap();
417-
418-
registry.set_environment_fallback_model(
419-
Some(ConfiguredModel {
420-
provider: provider.clone(),
421-
model: provider.default_model(cx).unwrap(),
422-
}),
423-
cx,
424-
);
425-
426-
let default_model = registry.default_model().unwrap();
427-
let fallback_model = registry.environment_fallback_model.clone().unwrap();
428-
429-
assert_eq!(default_model.model.id(), fallback_model.model.id());
430-
assert_eq!(default_model.provider.id(), fallback_model.provider.id());
431-
});
432-
}
433413
}

crates/language_models/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ ollama = { workspace = true, features = ["schemars"] }
4444
open_ai = { workspace = true, features = ["schemars"] }
4545
open_router = { workspace = true, features = ["schemars"] }
4646
partial-json-fixer.workspace = true
47-
project.workspace = true
4847
release_channel.workspace = true
4948
schemars.workspace = true
5049
serde.workspace = true

0 commit comments

Comments
 (0)