55import struct
66import tempfile
77from io import BufferedWriter
8- from typing import Any , BinaryIO , Sequence
8+ from enum import Enum , auto
9+ from typing import Any , IO , Sequence
910
1011import numpy as np
1112
2122 TokenType ,
2223)
2324
class WriterState(Enum):
    """Progress marker for the sections a GGUFWriter has written to its output.

    The writer's ``write_*_to_file`` methods check the current state before
    writing and advance it afterwards, so the sections can only be written in
    the order the members are declared here.
    """

    EMPTY = auto()    # nothing written yet; initial state set in GGUFWriter.__init__
    HEADER = auto()   # set by write_header_to_file
    KV_DATA = auto()  # set by write_kv_data_to_file
    TI_DATA = auto()  # set by write_ti_data_to_file
30+
2431class GGUFWriter :
2532 fout : BufferedWriter
26- arch : str
27- offset_tensor = 0
28- data_alignment = GGUF_DEFAULT_ALIGNMENT
29- kv_data = b""
30- kv_data_count = 0
31- ti_data = b""
32- ti_data_count = 0
33- use_temp_file : bool
34- temp_file : tempfile .SpooledTemporaryFile [bytes ] | None = None
35- tensors : list [tuple [np .ndarray [Any , Any ], int ]]
33+ temp_file : tempfile .SpooledTemporaryFile [bytes ] | None
34+ tensors : list [np .ndarray [Any , Any ]]
3635 _simple_value_packing = {
3736 GGUFValueType .UINT8 : "B" ,
3837 GGUFValueType .INT8 : "b" ,
@@ -60,27 +59,47 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool
6059 self .fout = open (path , "wb" )
6160 self .arch = arch
6261 self .endianess = endianess
63- self .add_architecture ()
62+ self .offset_tensor = 0
63+ self .data_alignment = GGUF_DEFAULT_ALIGNMENT
64+ self .kv_data = b""
65+ self .kv_data_count = 0
66+ self .ti_data = b""
67+ self .ti_data_count = 0
6468 self .use_temp_file = use_temp_file
69+ self .temp_file = None
6570 self .tensors = []
6671 print ("gguf: This GGUF file is for {0} Endian only"
6772 .format ("Big" if self .endianess == GGUFEndian .BIG else "Little" ))
73+ self .state = WriterState .EMPTY
74+
75+ self .add_architecture ()
6876
def write_header_to_file(self) -> None:
    """Write the GGUF header and advance the writer to ``WriterState.HEADER``.

    Emits, in order: the magic number, the format version, the tensor-info
    count and the key/value count, then flushes the output.

    The counts are written as they stand at call time; add_tensor_info
    refuses to run once the state has left ``WriterState.EMPTY``.

    Raises:
        ValueError: if the writer state is not ``WriterState.EMPTY``.
    """
    if self.state is not WriterState.EMPTY:
        raise ValueError(f'Expected output file to be empty, got {self.state}')

    # The magic carries its own "<" prefix and bypasses the prefix that
    # _write_packed would add -- presumably so it is always little-endian
    # regardless of self.endianess; confirm against _write_packed.
    self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix=True)
    self._write_packed("I", GGUF_VERSION)
    self._write_packed("Q", self.ti_data_count)
    self._write_packed("Q", self.kv_data_count)
    self.flush()
    self.state = WriterState.HEADER
7687
def write_kv_data_to_file(self) -> None:
    """Write the accumulated key/value bytes and advance to ``WriterState.KV_DATA``.

    Writes ``self.kv_data`` to ``self.fout`` in one call, then flushes.

    Raises:
        ValueError: if the writer state is not ``WriterState.HEADER``, i.e.
            the header has not been written immediately before this call.
    """
    if self.state is not WriterState.HEADER:
        raise ValueError(f'Expected output file to contain the header, got {self.state}')

    self.fout.write(self.kv_data)
    self.flush()
    self.state = WriterState.KV_DATA
8095
def write_ti_data_to_file(self) -> None:
    """Write the accumulated tensor-info bytes and advance to ``WriterState.TI_DATA``.

    Writes ``self.ti_data`` to ``self.fout`` in one call, then flushes.

    Raises:
        ValueError: if the writer state is not ``WriterState.KV_DATA``, i.e.
            the key/value section has not been written immediately before
            this call.
    """
    if self.state is not WriterState.KV_DATA:
        raise ValueError(f'Expected output file to contain KV data, got {self.state}')

    self.fout.write(self.ti_data)
    self.flush()
    self.state = WriterState.TI_DATA
84103
85104 def add_key (self , key : str ) -> None :
86105 self .add_val (key , GGUFValueType .STRING , add_vtype = False )
@@ -173,6 +192,9 @@ def ggml_pad(x: int, n: int) -> int:
173192 return ((x + n - 1 ) // n ) * n
174193
175194 def add_tensor_info (self , name : str , tensor_shape : Sequence [int ], tensor_dtype : np .dtype [np .float16 ] | np .dtype [np .float32 ], tensor_nbytes : int , raw_dtype : GGMLQuantizationType | None = None ) -> None :
195+ if self .state is not WriterState .EMPTY :
196+ raise ValueError (f'Expected output file to be empty, got { self .state } ' )
197+
176198 if raw_dtype is None and tensor_dtype not in (np .float32 , np .float16 ):
177199 raise ValueError ("Only F32 and F16 tensors are supported for now" )
178200
@@ -203,23 +225,21 @@ def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequenc
203225 shape : Sequence [int ] = raw_shape if raw_shape is not None else tensor .shape
204226 self .add_tensor_info (name , shape , tensor .dtype , tensor .nbytes , raw_dtype = raw_dtype )
205227
206- pad = GGUFWriter .ggml_pad (tensor .nbytes , self .data_alignment ) - tensor .nbytes
207-
208- if self .temp_file is None :
209- self .tensors .append ((tensor , pad ))
210- return
228+ if self .temp_file is None :
229+ self .tensors .append (tensor )
211230
212231 tensor .tofile (self .temp_file )
232+ self .write_padding (self .temp_file , tensor .nbytes )
213233
214- if pad != 0 :
215- self .temp_file .write (bytes ([0 ] * pad ))
216-
def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
    """Write the zero bytes needed to round *n* up to a multiple of the alignment.

    Args:
        fp: binary stream that receives the padding bytes.
        n: byte count to be padded; callers in this file pass either a
            tensor's ``nbytes`` or the current ``fout.tell()`` position.
        align: alignment in bytes; ``self.data_alignment`` is used when None.

    Nothing is written when *n* is already a multiple of the alignment.
    """
    alignment = align if align is not None else self.data_alignment
    pad = GGUFWriter.ggml_pad(n, alignment) - n
    if pad != 0:
        # bytes(pad) yields `pad` zero bytes without building a temporary list.
        fp.write(bytes(pad))
221238
222239 def write_tensor_data (self , tensor : np .ndarray [Any , Any ]) -> None :
240+ if self .state is not WriterState .TI_DATA :
241+ raise ValueError (f'Expected output file to contain tensor info, got { self .state } ' )
242+
223243 if self .endianess == GGUFEndian .BIG :
224244 tensor .byteswap (inplace = True )
225245 self .write_padding (self .fout , self .fout .tell ())
@@ -232,10 +252,13 @@ def write_tensors_to_file(self) -> None:
232252 self .write_padding (self .fout , self .fout .tell ())
233253
234254 if self .temp_file is None :
235- for (currtensor , currpad ) in self .tensors :
236- currtensor .tofile (self .fout )
237- if currpad != 0 :
238- self .fout .write (bytes ([0 ] * currpad ))
255+ while True :
256+ try :
257+ tensor = self .tensors .pop (0 )
258+ except IndexError :
259+ break
260+ tensor .tofile (self .fout )
261+ self .write_padding (self .fout , tensor .nbytes )
239262 return
240263
241264 self .temp_file .seek (0 )
0 commit comments