Skip to content

Commit 4501c04

Browse files
committed
experiment with "ref" form of extra
1 parent 313d199 commit 4501c04

File tree

1 file changed

+100
-39
lines changed

1 file changed

+100
-39
lines changed

src/serializers/type_serializers/function.rs

Lines changed: 100 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use std::borrow::Cow;
2-
use std::sync::Arc;
2+
use std::f32::consts::E;
3+
use std::marker::PhantomData;
4+
use std::ptr::{self, NonNull};
5+
use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
36

47
use pyo3::exceptions::{PyAttributeError, PyRecursionError, PyRuntimeError};
58
use pyo3::gc::PyVisit;
@@ -11,6 +14,7 @@ use pyo3::PyTraverseError;
1114
use pyo3::types::PyString;
1215

1316
use crate::definitions::DefinitionsBuilder;
17+
use crate::serializers::extra;
1418
use crate::tools::SchemaDict;
1519
use crate::tools::{function_name, py_err, py_error_type};
1620
use crate::{PydanticOmit, PydanticSerializationUnexpectedValue};
@@ -393,7 +397,9 @@ impl FunctionWrapSerializer {
393397
) -> PyResult<(bool, PyObject)> {
394398
let py = value.py();
395399
if self.when_used.should_use(value, extra) {
396-
let serialize = SerializationCallable::new(&self.serializer, include, exclude, extra);
400+
let extra_ref_guard = ExtraRef::new(extra);
401+
let serialize =
402+
SerializationCallable::new(&self.serializer, include, exclude, extra_ref_guard.inner().clone());
397403
let v = if self.is_field_serializer {
398404
if let Some(model) = extra.model {
399405
if self.info_arg {
@@ -434,11 +440,56 @@ impl_py_gc_traverse!(FunctionWrapSerializer {
434440

435441
function_type_serializer!(FunctionWrapSerializer);
436442

443+
/// A wrapper around `&Extra` which drops the lifetime, in order to be stored inside a Python object.
444+
#[derive(Clone)]
445+
struct ExtraRef {
446+
value: Arc<RwLock<Option<*const Extra<'static>>>>,
447+
}
448+
449+
// Safety: `&Extra` is `Send + Sync`
450+
unsafe impl Send for ExtraRef {}
451+
unsafe impl Sync for ExtraRef {}
452+
453+
impl ExtraRef {
454+
fn new<'a>(extra: &'a Extra<'a>) -> ExtraRefGuard<'a> {
455+
ExtraRefGuard(
456+
ExtraRef {
457+
value: Arc::new(RwLock::new(Some(ptr::from_ref(extra).cast()))),
458+
},
459+
PhantomData,
460+
)
461+
}
462+
463+
fn map<R>(&self, f: impl FnOnce(&Extra<'_>) -> R) -> Option<R> {
464+
// FIXME: deal with lock poisoning?, use try_read
465+
let guard = self.value.read().unwrap();
466+
guard.as_ref().map(|ptr| {
467+
// Safety: we ensure that the pointer is valid while `ExtraRef` is alive
468+
let extra: &Extra = unsafe { &**ptr };
469+
f(extra)
470+
})
471+
}
472+
}
473+
474+
struct ExtraRefGuard<'a>(ExtraRef, PhantomData<&'a Extra<'a>>);
475+
476+
impl ExtraRefGuard<'_> {
477+
fn inner(&self) -> &ExtraRef {
478+
&self.0
479+
}
480+
}
481+
482+
impl Drop for ExtraRefGuard<'_> {
483+
fn drop(&mut self) {
484+
let mut guard = self.0.value.write().unwrap();
485+
*guard = None;
486+
}
487+
}
488+
437489
#[pyclass(module = "pydantic_core._pydantic_core")]
438-
#[cfg_attr(debug_assertions, derive(Debug))]
439490
pub(crate) struct SerializationCallable {
440491
serializer: Arc<CombinedSerializer>,
441-
extra_owned: ExtraOwned,
492+
extra: ExtraRef,
442493
filter: AnyFilter,
443494
include: Option<PyObject>,
444495
exclude: Option<PyObject>,
@@ -449,11 +500,11 @@ impl SerializationCallable {
449500
serializer: &Arc<CombinedSerializer>,
450501
include: Option<&Bound<'_, PyAny>>,
451502
exclude: Option<&Bound<'_, PyAny>>,
452-
extra: &Extra,
503+
extra: ExtraRef,
453504
) -> Self {
454505
Self {
455506
serializer: serializer.clone(),
456-
extra_owned: ExtraOwned::new(extra),
507+
extra: extra,
457508
filter: AnyFilter::new(),
458509
include: include.map(|v| v.clone().unbind()),
459510
exclude: exclude.map(|v| v.clone().unbind()),
@@ -467,24 +518,22 @@ impl SerializationCallable {
467518
if let Some(exclude) = &self.exclude {
468519
visit.call(exclude)?;
469520
}
470-
if let Some(model) = &self.extra_owned.model {
471-
visit.call(model)?;
472-
}
473-
if let Some(fallback) = &self.extra_owned.fallback {
474-
visit.call(fallback)?;
475-
}
476-
if let Some(context) = &self.extra_owned.context {
477-
visit.call(context)?;
478-
}
521+
self.extra
522+
.map(|extra| {
523+
// FIXME: not sound to get .read() of extra inside GC, probably need to make `Extra` not
524+
// have the `'py` lifetime
525+
visit.call(extra.model.map(Bound::as_unbound))?;
526+
visit.call(extra.fallback.map(Bound::as_unbound))?;
527+
visit.call(extra.context.map(Bound::as_unbound))?;
528+
Ok(())
529+
})
530+
.transpose()?;
479531
Ok(())
480532
}
481533

482534
fn __clear__(&mut self) {
483535
self.include = None;
484536
self.exclude = None;
485-
self.extra_owned.model = None;
486-
self.extra_owned.fallback = None;
487-
self.extra_owned.context = None;
488537
}
489538
}
490539

@@ -503,28 +552,40 @@ impl SerializationCallable {
503552

504553
let include = self.include.as_ref().map(|o| o.bind(py));
505554
let exclude = self.exclude.as_ref().map(|o| o.bind(py));
506-
let extra = self.extra_owned.to_extra(py);
507555

508-
if let Some(index_key) = index_key {
509-
let filter = if let Ok(index) = index_key.extract::<usize>() {
510-
self.filter.index_filter(index, include, exclude, None)?
511-
} else {
512-
self.filter.key_filter(index_key, include, exclude)?
513-
};
514-
if let Some((next_include, next_exclude)) = filter {
515-
let v =
516-
self.serializer
517-
.to_python_no_infer(value, next_include.as_ref(), next_exclude.as_ref(), &extra)?;
518-
extra.warnings.final_check(py)?;
519-
Ok(Some(v))
520-
} else {
521-
Err(PydanticOmit::new_err())
522-
}
523-
} else {
524-
let v = self.serializer.to_python_no_infer(value, include, exclude, &extra)?;
525-
extra.warnings.final_check(py)?;
526-
Ok(Some(v))
527-
}
556+
// FIXME: the &T is not sound here, since the guard is dropped at the end of this statement.
557+
// Probably need to have a .map() method to avoid scope leak?
558+
self.extra
559+
.map(|extra| {
560+
if let Some(index_key) = index_key {
561+
let filter = if let Ok(index) = index_key.extract::<usize>() {
562+
self.filter.index_filter(index, include, exclude, None)?
563+
} else {
564+
self.filter.key_filter(index_key, include, exclude)?
565+
};
566+
if let Some((next_include, next_exclude)) = filter {
567+
let v = self.serializer.to_python_no_infer(
568+
value,
569+
next_include.as_ref(),
570+
next_exclude.as_ref(),
571+
&extra,
572+
)?;
573+
extra.warnings.final_check(py)?;
574+
Ok(Some(v))
575+
} else {
576+
Err(PydanticOmit::new_err())
577+
}
578+
} else {
579+
let v = self.serializer.to_python_no_infer(value, include, exclude, &extra)?;
580+
extra.warnings.final_check(py)?;
581+
Ok(Some(v))
582+
}
583+
})
584+
.unwrap_or_else(|| {
585+
Err(PyRuntimeError::new_err(
586+
"Attempted to use SerializationCallable after its wrap validation context was exited",
587+
))
588+
})
528589
}
529590

530591
fn __repr__(&self) -> PyResult<String> {

0 commit comments

Comments
 (0)