11use std:: borrow:: Cow ;
2- use std:: sync:: Arc ;
2+ use std:: f32:: consts:: E ;
3+ use std:: marker:: PhantomData ;
4+ use std:: ptr:: { self , NonNull } ;
5+ use std:: sync:: { Arc , Mutex , MutexGuard , RwLock , RwLockReadGuard } ;
36
47use pyo3:: exceptions:: { PyAttributeError , PyRecursionError , PyRuntimeError } ;
58use pyo3:: gc:: PyVisit ;
@@ -11,6 +14,7 @@ use pyo3::PyTraverseError;
1114use pyo3:: types:: PyString ;
1215
1316use crate :: definitions:: DefinitionsBuilder ;
17+ use crate :: serializers:: extra;
1418use crate :: tools:: SchemaDict ;
1519use crate :: tools:: { function_name, py_err, py_error_type} ;
1620use crate :: { PydanticOmit , PydanticSerializationUnexpectedValue } ;
@@ -393,7 +397,9 @@ impl FunctionWrapSerializer {
393397 ) -> PyResult < ( bool , PyObject ) > {
394398 let py = value. py ( ) ;
395399 if self . when_used . should_use ( value, extra) {
396- let serialize = SerializationCallable :: new ( & self . serializer , include, exclude, extra) ;
400+ let extra_ref_guard = ExtraRef :: new ( extra) ;
401+ let serialize =
402+ SerializationCallable :: new ( & self . serializer , include, exclude, extra_ref_guard. inner ( ) . clone ( ) ) ;
397403 let v = if self . is_field_serializer {
398404 if let Some ( model) = extra. model {
399405 if self . info_arg {
@@ -434,11 +440,56 @@ impl_py_gc_traverse!(FunctionWrapSerializer {
434440
435441function_type_serializer ! ( FunctionWrapSerializer ) ;
436442
443+ /// A wrapper around `&Extra` which drops the lifetime, in order to be stored inside a Python object.
444+ #[ derive( Clone ) ]
445+ struct ExtraRef {
446+ value : Arc < RwLock < Option < * const Extra < ' static > > > > ,
447+ }
448+
449+ // Safety: `&Extra` is `Send + Sync`
450+ unsafe impl Send for ExtraRef { }
451+ unsafe impl Sync for ExtraRef { }
452+
453+ impl ExtraRef {
454+ fn new < ' a > ( extra : & ' a Extra < ' a > ) -> ExtraRefGuard < ' a > {
455+ ExtraRefGuard (
456+ ExtraRef {
457+ value : Arc :: new ( RwLock :: new ( Some ( ptr:: from_ref ( extra) . cast ( ) ) ) ) ,
458+ } ,
459+ PhantomData ,
460+ )
461+ }
462+
463+ fn map < R > ( & self , f : impl FnOnce ( & Extra < ' _ > ) -> R ) -> Option < R > {
464+ // FIXME: deal with lock poisoning?, use try_read
465+ let guard = self . value . read ( ) . unwrap ( ) ;
466+ guard. as_ref ( ) . map ( |ptr| {
467+ // Safety: we ensure that the pointer is valid while `ExtraRef` is alive
468+ let extra: & Extra = unsafe { & * * ptr } ;
469+ f ( extra)
470+ } )
471+ }
472+ }
473+
474+ struct ExtraRefGuard < ' a > ( ExtraRef , PhantomData < & ' a Extra < ' a > > ) ;
475+
476+ impl ExtraRefGuard < ' _ > {
477+ fn inner ( & self ) -> & ExtraRef {
478+ & self . 0
479+ }
480+ }
481+
482+ impl Drop for ExtraRefGuard < ' _ > {
483+ fn drop ( & mut self ) {
484+ let mut guard = self . 0 . value . write ( ) . unwrap ( ) ;
485+ * guard = None ;
486+ }
487+ }
488+
437489#[ pyclass( module = "pydantic_core._pydantic_core" ) ]
438- #[ cfg_attr( debug_assertions, derive( Debug ) ) ]
439490pub ( crate ) struct SerializationCallable {
440491 serializer : Arc < CombinedSerializer > ,
441- extra_owned : ExtraOwned ,
492+ extra : ExtraRef ,
442493 filter : AnyFilter ,
443494 include : Option < PyObject > ,
444495 exclude : Option < PyObject > ,
@@ -449,11 +500,11 @@ impl SerializationCallable {
449500 serializer : & Arc < CombinedSerializer > ,
450501 include : Option < & Bound < ' _ , PyAny > > ,
451502 exclude : Option < & Bound < ' _ , PyAny > > ,
452- extra : & Extra ,
503+ extra : ExtraRef ,
453504 ) -> Self {
454505 Self {
455506 serializer : serializer. clone ( ) ,
456- extra_owned : ExtraOwned :: new ( extra) ,
507+ extra : extra,
457508 filter : AnyFilter :: new ( ) ,
458509 include : include. map ( |v| v. clone ( ) . unbind ( ) ) ,
459510 exclude : exclude. map ( |v| v. clone ( ) . unbind ( ) ) ,
@@ -467,24 +518,22 @@ impl SerializationCallable {
467518 if let Some ( exclude) = & self . exclude {
468519 visit. call ( exclude) ?;
469520 }
470- if let Some ( model) = & self . extra_owned . model {
471- visit. call ( model) ?;
472- }
473- if let Some ( fallback) = & self . extra_owned . fallback {
474- visit. call ( fallback) ?;
475- }
476- if let Some ( context) = & self . extra_owned . context {
477- visit. call ( context) ?;
478- }
521+ self . extra
522+ . map ( |extra| {
523+ // FIXME: not sound to get .read() of extra inside GC, probably need to make `Extra` not
524+ // have the `'py` lifetime
525+ visit. call ( extra. model . map ( Bound :: as_unbound) ) ?;
526+ visit. call ( extra. fallback . map ( Bound :: as_unbound) ) ?;
527+ visit. call ( extra. context . map ( Bound :: as_unbound) ) ?;
528+ Ok ( ( ) )
529+ } )
530+ . transpose ( ) ?;
479531 Ok ( ( ) )
480532 }
481533
482534 fn __clear__ ( & mut self ) {
483535 self . include = None ;
484536 self . exclude = None ;
485- self . extra_owned . model = None ;
486- self . extra_owned . fallback = None ;
487- self . extra_owned . context = None ;
488537 }
489538}
490539
@@ -503,28 +552,40 @@ impl SerializationCallable {
503552
504553 let include = self . include . as_ref ( ) . map ( |o| o. bind ( py) ) ;
505554 let exclude = self . exclude . as_ref ( ) . map ( |o| o. bind ( py) ) ;
506- let extra = self . extra_owned . to_extra ( py) ;
507555
508- if let Some ( index_key) = index_key {
509- let filter = if let Ok ( index) = index_key. extract :: < usize > ( ) {
510- self . filter . index_filter ( index, include, exclude, None ) ?
511- } else {
512- self . filter . key_filter ( index_key, include, exclude) ?
513- } ;
514- if let Some ( ( next_include, next_exclude) ) = filter {
515- let v =
516- self . serializer
517- . to_python_no_infer ( value, next_include. as_ref ( ) , next_exclude. as_ref ( ) , & extra) ?;
518- extra. warnings . final_check ( py) ?;
519- Ok ( Some ( v) )
520- } else {
521- Err ( PydanticOmit :: new_err ( ) )
522- }
523- } else {
524- let v = self . serializer . to_python_no_infer ( value, include, exclude, & extra) ?;
525- extra. warnings . final_check ( py) ?;
526- Ok ( Some ( v) )
527- }
556+ // FIXME: the &T is not sound here, since the guard is dropped at the end of this statement.
557+ // Probably need to have a .map() method to avoid scope leak?
558+ self . extra
559+ . map ( |extra| {
560+ if let Some ( index_key) = index_key {
561+ let filter = if let Ok ( index) = index_key. extract :: < usize > ( ) {
562+ self . filter . index_filter ( index, include, exclude, None ) ?
563+ } else {
564+ self . filter . key_filter ( index_key, include, exclude) ?
565+ } ;
566+ if let Some ( ( next_include, next_exclude) ) = filter {
567+ let v = self . serializer . to_python_no_infer (
568+ value,
569+ next_include. as_ref ( ) ,
570+ next_exclude. as_ref ( ) ,
571+ & extra,
572+ ) ?;
573+ extra. warnings . final_check ( py) ?;
574+ Ok ( Some ( v) )
575+ } else {
576+ Err ( PydanticOmit :: new_err ( ) )
577+ }
578+ } else {
579+ let v = self . serializer . to_python_no_infer ( value, include, exclude, & extra) ?;
580+ extra. warnings . final_check ( py) ?;
581+ Ok ( Some ( v) )
582+ }
583+ } )
584+ . unwrap_or_else ( || {
585+ Err ( PyRuntimeError :: new_err (
586+ "Attempted to use SerializationCallable after its wrap validation context was exited" ,
587+ ) )
588+ } )
528589 }
529590
530591 fn __repr__ ( & self ) -> PyResult < String > {
0 commit comments