@@ -90,39 +90,51 @@ static void sigint_handler(int signo) {
9090
9191class chat_formatter {
9292public:
93- chat_formatter (common_params & params, std::vector<common_chat_msg> & chat_msgs, struct common_chat_templates * chat_templates)
93+
94+ struct result {
95+ std::string formatted;
96+ bool tool_was_called;
97+ };
98+
99+ chat_formatter (common_params & params,
100+ std::vector<common_chat_msg> & chat_msgs,
101+ struct common_chat_templates * chat_templates)
102+
94103 : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates) {}
95104
96105#ifdef LLAMA_USE_TOOLCALL
97106 chat_formatter (common_params & params,
98107 std::vector<common_chat_msg> & chat_msgs,
99108 struct common_chat_templates * chat_templates,
100109 const llama_vocab * vocab,
101- toolcall::client::ptr tc_client,
102- common_chat_format * chat_format)
110+ toolcall::client::ptr tc_client)
103111
104- : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client), chat_format_(chat_format) {}
112+ : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates),
113+ vocab_(vocab), tc_client_(tc_client),
114+ chat_format_(COMMON_CHAT_FORMAT_CONTENT_ONLY),
115+ formatted_() {}
105116#endif
106117
107- std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false ) {
118+ chat_formatter::result operator () (const std::string & role, const std::string & content) {
119+
120+ common_chat_msg new_msg = common_chat_parse (content, chat_format_);
121+ new_msg.role = role;
108122
109123 common_chat_templates_inputs cinputs;
110124 cinputs.use_jinja = params_.use_jinja ;
111125 cinputs.add_generation_prompt = (role == " user" );
112126#ifdef LLAMA_USE_TOOLCALL
113- if (tc_client_ != nullptr && use_toolcalls ) {
127+ if (tc_client_ != nullptr ) {
114128 cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat (tc_client_->tool_choice ());
115129 cinputs.tools = common_chat_tools_parse_oaicompat (tc_client_->tool_list ());
116130 }
117131#endif
118- for (const auto & msg : chat_msgs_) {
119- cinputs.messages .push_back (common_chat_msg (msg));
120- }
121-
122- common_chat_msg new_msg = common_chat_parse (content, *chat_format_);
123- new_msg.role = role;
132+ cinputs.messages .assign (chat_msgs_.cbegin (), chat_msgs_.cend ());
133+ cinputs.messages .push_back (new_msg);
134+ chat_msgs_.push_back (new_msg);
124135
125- if (! new_msg.tool_calls .empty ()) {
136+ bool tool_was_called = false ;
137+ if (! new_msg.tool_calls .empty ()) { // Call tool and re-prompt
126138 nlohmann::json result_array = nlohmann::json::array ();
127139 for (const auto & tc : new_msg.tool_calls ) {
128140 toolcall::result_set res = tc_client_->call (tc.name , tc.arguments , tc.id );
@@ -132,21 +144,28 @@ class chat_formatter {
132144 }
133145 }
134146 }
135- new_msg.content += result_array.dump (-1 );
147+ common_chat_msg toolcall_msg;
148+ toolcall_msg.role = " tool" ;
149+ toolcall_msg.content = result_array.dump (-1 );
150+
151+ cinputs.add_generation_prompt = true ;
152+ cinputs.messages .push_back (toolcall_msg);
153+ chat_msgs_.push_back (toolcall_msg);
154+
155+ tool_was_called = true ;
136156 }
137157
138- cinputs.messages .push_back (new_msg);
139158 common_chat_params cparams = common_chat_templates_apply (chat_templates_, cinputs);
159+ std::string formatted = cparams.prompt .substr (formatted_.size (), cparams.prompt .size ());
160+ formatted_ = cparams.prompt ;
140161
141- auto formatted = cparams.prompt ;
142- chat_msgs_.push_back (new_msg);
143162 LOG_DBG (" formatted: '%s'\n " , formatted.c_str ());
144163
145164#ifdef LLAMA_USE_TOOLCALL
146- if (chat_format_) * chat_format_ = cparams.format ;
165+ chat_format_ = cparams.format ;
146165 common_chat_grammar_to_sampler (&cparams, vocab_, &params_.sampling );
148167#endif
149- return formatted;
168+ return chat_formatter::result{ std::move ( formatted), tool_was_called} ;
150169 }
151170
152171private:
@@ -157,7 +176,8 @@ class chat_formatter {
157176#ifdef LLAMA_USE_TOOLCALL
158177 const llama_vocab * vocab_;
159178 toolcall::client::ptr tc_client_;
160- common_chat_format * chat_format_;
179+ common_chat_format chat_format_;
180+ std::string formatted_;
161181#endif
162182};
163183
@@ -355,8 +375,7 @@ int main(int argc, char ** argv) {
355375 if (tc_client) {
356376 tc_client->initialize ();
357377 }
358- common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
359- chat_formatter chat_add_and_format (params, chat_msgs, chat_templates.get (), vocab, tc_client, &chat_format);
378+ chat_formatter chat_add_and_format (params, chat_msgs, chat_templates.get (), vocab, tc_client);
360379#else
361380 chat_formatter chat_add_and_format (params, chat_msgs, chat_templates.get ());
362381#endif
@@ -366,12 +385,12 @@ int main(int argc, char ** argv) {
366385 if (params.conversation_mode && params.enable_chat_template ) {
367386 if (!params.system_prompt .empty ()) {
368387 // format the system prompt (will use template default if empty)
369- chat_add_and_format (" system" , params.system_prompt , true );
388+ chat_add_and_format (" system" , params.system_prompt );
370389 }
371390
372391 if (!params.prompt .empty ()) {
373392 // format and append the user prompt
374- chat_add_and_format (" user" , params.prompt , true );
393+ chat_add_and_format (" user" , params.prompt );
375394 } else {
376395 waiting_for_first_input = true ;
377396 }
@@ -905,9 +924,15 @@ int main(int argc, char ** argv) {
905924 }
906925
907926 if (params.enable_chat_template ) {
908- chat_add_and_format (" assistant" , assistant_ss.str (), true );
909- is_interacting = true ;
910- LOG (" \n " );
927+ auto format_res = chat_add_and_format (" assistant" , assistant_ss.str ());
928+ if (format_res.tool_was_called ) {
929+ auto format_res_tok = common_tokenize (ctx, format_res.formatted , false , true );
930+ embd_inp.insert (embd_inp.end (), format_res_tok.begin (), format_res_tok.end ());
931+
932+ } else {
933+ is_interacting = true ;
934+ LOG (" \n " );
935+ }
911936 }
912937 }
913938 }
@@ -975,7 +1000,7 @@ int main(int argc, char ** argv) {
9751000
9761001 bool format_chat = params.conversation_mode && params.enable_chat_template ;
9771002 std::string user_inp = format_chat
978- ? chat_add_and_format (" user" , std::move (buffer))
1003+ ? chat_add_and_format (" user" , std::move (buffer)). formatted
9791004 : std::move (buffer);
9811006 // TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
9811006 const auto line_pfx = common_tokenize (ctx, params.input_prefix , false , true );
0 commit comments