@@ -64,7 +64,7 @@ if n_kv_req > n_ctx {
6464
6565var buffer : [ CChar ] = [ ]
6666for id : llama_token in tokens {
67- print ( token_to_piece ( token: id) , terminator: " " )
67+ print ( token_to_piece ( token: id, buffer : & buffer ) ?? " " , terminator: " " )
6868}
6969
7070print ( " \n " )
@@ -101,6 +101,7 @@ if n_parallel > 1 {
101101}
102102
103103var streams : [ String ] = . init( repeating: " " , count: n_parallel)
104+ var streamBuffers : [ [ CChar ] ] = . init( repeating: [ ] , count: n_parallel)
104105var i_batch = [ Int32] ( repeating: batch. n_tokens - 1 , count: n_parallel)
105106
106107var n_cur = batch. n_tokens
@@ -157,12 +158,13 @@ while n_cur <= n_len {
157158 continue
158159 }
159160
161+ let nextStringPiece = token_to_piece ( token: new_token_id, buffer: & streamBuffers[ i] ) ?? " "
162+
160163 // if there is only one stream, we print immediately to stdout
161164 if n_parallel == 1 {
162- print ( token_to_piece ( token : new_token_id ) , terminator: " " )
165+ print ( nextStringPiece , terminator: " " )
163166 }
164-
165- streams [ i] += token_to_piece ( token: new_token_id)
167+ streams [ i] += nextStringPiece
166168
167169 // push this new token for next evaluation
168170 batch. token [ Int ( batch. n_tokens) ] = new_token_id
@@ -216,11 +218,38 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
216218 return swiftTokens
217219}
218220
219- private func token_to_piece( token: llama_token ) -> String {
220- let result = UnsafeMutablePointer< Int8> . allocate( capacity: 8 )
221- result. initialize ( repeating: Int8 ( 0 ) , count: 8 )
222- let _ = llama_token_to_piece ( model, token, result, 8 )
223- let resultStr = String ( cString: result)
224- result. deallocate ( )
225- return resultStr
221+ private func token_to_piece( token: llama_token , buffer: inout [ CChar ] ) -> String ? {
222+ var result = [ CChar] ( repeating: 0 , count: 8 )
223+ let nTokens = llama_token_to_piece ( model, token, & result, Int32 ( result. count) )
224+ if nTokens < 0 {
225+ if result. count >= - Int( nTokens) {
226+ result. removeLast ( - Int( nTokens) )
227+ } else {
228+ result. removeAll ( )
229+ }
230+ let check = llama_token_to_piece (
231+ model,
232+ token,
233+ & result,
234+ Int32 ( result. count)
235+ )
236+ assert ( check == nTokens)
237+ } else {
238+ result. removeLast ( result. count - Int( nTokens) )
239+ }
240+ if buffer. isEmpty, let utfString = String ( cString: result + [ 0 ] , encoding: . utf8) {
241+ return utfString
242+ } else {
243+ buffer. append ( contentsOf: result)
244+ let data = Data ( buffer. map { UInt8 ( bitPattern: $0) } )
245+ if buffer. count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer
246+ buffer = [ ]
247+ }
248+ guard let bufferString = String ( data: data, encoding: . utf8) else {
249+ return nil
250+ }
251+ buffer = [ ]
252+ return bufferString
253+ }
254+ return nil
226255}
0 commit comments