From e34a5eb05f51889882e254e36050760ecc7e9105 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 23 Aug 2023 16:58:45 +0100 Subject: [PATCH 001/550] fix big int subclass extraction (#919) --- src/input/input_python.rs | 6 ++---- src/input/return_enums.rs | 9 +++++++++ tests/validators/test_int.py | 2 ++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 24f37504f..5ebe9025a 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -304,8 +304,7 @@ impl<'a> Input<'a> for PyAny { Err(ValError::new(ErrorTypeDefaults::IntType, self)) } else { // force to an int to upcast to a pure python int - let int = self.extract::()?; - Ok(EitherInt::I64(int)) + EitherInt::upcast(self) } } else { Err(ValError::new(ErrorTypeDefaults::IntType, self)) @@ -320,8 +319,7 @@ impl<'a> Input<'a> for PyAny { str_as_int(self, &cow_str) } else if PyInt::is_type_of(self) { // force to an int to upcast to a pure python int to maintain current behaviour - let int = self.extract::()?; - Ok(EitherInt::I64(int)) + EitherInt::upcast(self) } else if let Ok(float) = self.extract::() { float_as_int(self, float) } else { diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 483b1918c..d0f3ba45d 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -820,6 +820,15 @@ pub enum EitherInt<'a> { } impl<'a> EitherInt<'a> { + pub fn upcast(py_any: &'a PyAny) -> ValResult { + // Safety: we know that py_any is a python int + if let Ok(int_64) = py_any.extract::() { + Ok(Self::I64(int_64)) + } else { + let big_int: BigInt = py_any.extract()?; + Ok(Self::BigInt(big_int)) + } + } pub fn into_i64(self, py: Python<'a>) -> ValResult<'a, i64> { match self { EitherInt::I64(i) => Ok(i), diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 4ef08081e..44a806118 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -421,7 +421,9 @@ def 
test_int_subclass() -> None: assert type(v_strict) == int assert v.validate_python(IntSubclass(1136885225876639845)) == 1136885225876639845 + assert v.validate_python(IntSubclass(i64_max + 7)) == i64_max + 7 assert v.validate_python(IntSubclass(1136885225876639845), strict=True) == 1136885225876639845 + assert v.validate_python(IntSubclass(i64_max + 7), strict=True) == i64_max + 7 def test_int_subclass_constraint() -> None: From c086caec1a200417f19850244282c06b5d4d1650 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 23 Aug 2023 11:19:17 -0500 Subject: [PATCH 002/550] Bump pydantic-core to 2.6.3 (#920) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f10f86d4a..25119d19c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,7 +242,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.6.2" +version = "2.6.3" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 93637dda9..b16624acc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.6.2" +version = "2.6.3" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From b076a8f41caeebc5df0e0d40fc54696d86ed6227 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 23 Aug 2023 11:43:29 -0500 Subject: [PATCH 003/550] Allow customizing serialization of extras (#911) --- python/pydantic_core/core_schema.py | 24 +++++++-------- src/serializers/fields.rs | 13 ++++++-- src/serializers/type_serializers/dataclass.rs | 2 +- src/serializers/type_serializers/model.rs | 9 +++++- src/serializers/type_serializers/tuple.rs | 4 +-- .../type_serializers/typed_dict.rs | 11 ++++++- src/validators/dataclass.rs | 25 +++++++++++++++- src/validators/model_fields.rs | 16 +++++----- src/validators/tuple.rs | 18 +++++------ 
src/validators/typed_dict.rs | 14 ++++----- tests/benchmarks/test_micro_benchmarks.py | 2 +- tests/serializers/test_dataclasses.py | 30 ++++++++++++++++++- tests/serializers/test_list_tuple.py | 4 +-- tests/serializers/test_model.py | 26 +++++++++++++++- tests/serializers/test_typed_dict.py | 15 ++++++++++ tests/test_json.py | 2 +- tests/validators/test_model_fields.py | 16 +++++----- tests/validators/test_tuple.py | 6 ++-- tests/validators/test_typed_dict.py | 16 +++++----- tests/validators/test_with_default.py | 2 +- 20 files changed, 185 insertions(+), 70 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 725d41eee..74442b44c 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -1366,7 +1366,7 @@ def list_schema( class TuplePositionalSchema(TypedDict, total=False): type: Required[Literal['tuple-positional']] items_schema: Required[List[CoreSchema]] - extra_schema: CoreSchema + extras_schema: CoreSchema strict: bool ref: str metadata: Any @@ -1376,7 +1376,7 @@ class TuplePositionalSchema(TypedDict, total=False): def tuple_positional_schema( items_schema: list[CoreSchema], *, - extra_schema: CoreSchema | None = None, + extras_schema: CoreSchema | None = None, strict: bool | None = None, ref: str | None = None, metadata: Any = None, @@ -1397,7 +1397,7 @@ def tuple_positional_schema( Args: items_schema: The value must be a tuple with items that match these schemas - extra_schema: The value must be a tuple with items that match this schema + extras_schema: The value must be a tuple with items that match this schema This was inspired by JSON schema's `prefixItems` and `items` fields. In python's `typing.Tuple`, you can't specify a type for "extra" items -- they must all be the same type if the length is variable. So this field won't be set from a `typing.Tuple` annotation on a pydantic model. 
@@ -1409,7 +1409,7 @@ def tuple_positional_schema( return _dict_not_none( type='tuple-positional', items_schema=items_schema, - extra_schema=extra_schema, + extras_schema=extras_schema, strict=strict, ref=ref, metadata=metadata, @@ -2829,7 +2829,7 @@ class TypedDictSchema(TypedDict, total=False): fields: Required[Dict[str, TypedDictField]] computed_fields: List[ComputedField] strict: bool - extra_validator: CoreSchema + extras_schema: CoreSchema # all these values can be set via config, equivalent fields have `typed_dict_` prefix extra_behavior: ExtraBehavior total: bool # default: True @@ -2845,7 +2845,7 @@ def typed_dict_schema( *, computed_fields: list[ComputedField] | None = None, strict: bool | None = None, - extra_validator: CoreSchema | None = None, + extras_schema: CoreSchema | None = None, extra_behavior: ExtraBehavior | None = None, total: bool | None = None, populate_by_name: bool | None = None, @@ -2871,7 +2871,7 @@ def typed_dict_schema( fields: The fields to use for the typed dict computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model strict: Whether the typed dict is strict - extra_validator: The extra validator to use for the typed dict + extras_schema: The extra validator to use for the typed dict ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core extra_behavior: The extra behavior to use for the typed dict @@ -2884,7 +2884,7 @@ def typed_dict_schema( fields=fields, computed_fields=computed_fields, strict=strict, - extra_validator=extra_validator, + extras_schema=extras_schema, extra_behavior=extra_behavior, total=total, populate_by_name=populate_by_name, @@ -2948,7 +2948,7 @@ class ModelFieldsSchema(TypedDict, total=False): model_name: str computed_fields: List[ComputedField] strict: bool - extra_validator: CoreSchema + extras_schema: CoreSchema # all these 
values can be set via config, equivalent fields have `typed_dict_` prefix extra_behavior: ExtraBehavior populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 @@ -2964,7 +2964,7 @@ def model_fields_schema( model_name: str | None = None, computed_fields: list[ComputedField] | None = None, strict: bool | None = None, - extra_validator: CoreSchema | None = None, + extras_schema: CoreSchema | None = None, extra_behavior: ExtraBehavior | None = None, populate_by_name: bool | None = None, from_attributes: bool | None = None, @@ -2991,7 +2991,7 @@ def model_fields_schema( model_name: The name of the model, used for error messages, defaults to "Model" computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model strict: Whether the typed dict is strict - extra_validator: The extra validator to use for the typed dict + extras_schema: The extra validator to use for the typed dict ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core extra_behavior: The extra behavior to use for the typed dict @@ -3005,7 +3005,7 @@ def model_fields_schema( model_name=model_name, computed_fields=computed_fields, strict=strict, - extra_validator=extra_validator, + extras_schema=extras_schema, extra_behavior=extra_behavior, populate_by_name=populate_by_name, from_attributes=from_attributes, diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index c8bf67971..f48f8e0b9 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -94,6 +94,7 @@ pub struct GeneralFieldsSerializer { fields: AHashMap, computed_fields: Option, mode: FieldsMode, + extra_serializer: Option>, // isize because we look up filter via `.hash()` which returns an isize filter: SchemaFilter, required_fields: usize, @@ -103,12 +104,14 @@ impl GeneralFieldsSerializer { pub(super) fn new( fields: 
AHashMap, mode: FieldsMode, + extra_serializer: Option, computed_fields: Option, ) -> Self { let required_fields = fields.values().filter(|f| f.required).count(); Self { fields, mode, + extra_serializer: extra_serializer.map(Box::new), filter: SchemaFilter::default(), computed_fields, required_fields, @@ -205,7 +208,10 @@ impl TypeSerializer for GeneralFieldsSerializer { used_req_fields += 1; } } else if self.mode == FieldsMode::TypedDictAllow { - let value = infer_to_python(value, next_include, next_exclude, &extra)?; + let value = match &self.extra_serializer { + Some(serializer) => serializer.to_python(value, next_include, next_exclude, &extra)?, + None => infer_to_python(value, next_include, next_exclude, &extra)?, + }; output_dict.set_item(key, value)?; } else if extra.check == SerCheck::Strict { return Err(PydanticSerializationUnexpectedValue::new_err(None)); @@ -227,7 +233,10 @@ impl TypeSerializer for GeneralFieldsSerializer { continue; } if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? 
{ - let value = infer_to_python(value, next_include, next_exclude, &td_extra)?; + let value = match &self.extra_serializer { + Some(serializer) => serializer.to_python(value, next_include, next_exclude, extra)?, + None => infer_to_python(value, next_include, next_exclude, extra)?, + }; output_dict.set_item(key, value)?; } } diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index aa3a180da..787e267dd 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -55,7 +55,7 @@ impl BuildSerializer for DataclassArgsBuilder { let computed_fields = ComputedFields::new(schema, config, definitions)?; - Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into()) + Ok(GeneralFieldsSerializer::new(fields, fields_mode, None, computed_fields).into()) } } diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 9d14dd16b..8a2eeb4e1 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -11,6 +11,7 @@ use super::{ CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField, TypeSerializer, }; +use crate::build_tools::py_schema_err; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; use crate::definitions::DefinitionsBuilder; use crate::serializers::errors::PydanticSerializationUnexpectedValue; @@ -38,6 +39,12 @@ impl BuildSerializer for ModelFieldsBuilder { let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?; let mut fields: AHashMap = AHashMap::with_capacity(fields_dict.len()); + let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) { + (Some(v), FieldsMode::ModelExtra) => Some(CombinedSerializer::build(v.extract()?, config, definitions)?), + (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), + (_, 
_) => None, + }; + for (key, value) in fields_dict { let key_py: &PyString = key.downcast()?; let key: String = key_py.extract()?; @@ -60,7 +67,7 @@ impl BuildSerializer for ModelFieldsBuilder { let computed_fields = ComputedFields::new(schema, config, definitions)?; - Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into()) + Ok(GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields).into()) } } diff --git a/src/serializers/type_serializers/tuple.rs b/src/serializers/type_serializers/tuple.rs index 46d7ef775..00c61e250 100644 --- a/src/serializers/type_serializers/tuple.rs +++ b/src/serializers/type_serializers/tuple.rs @@ -159,8 +159,8 @@ impl BuildSerializer for TuplePositionalSerializer { let py = schema.py(); let items: &PyList = schema.get_as_req(intern!(py, "items_schema"))?; - let extra_serializer = match schema.get_as::<&PyDict>(intern!(py, "extra_schema"))? { - Some(extra_schema) => CombinedSerializer::build(extra_schema, config, definitions)?, + let extra_serializer = match schema.get_as::<&PyDict>(intern!(py, "extras_schema"))? 
{ + Some(extras_schema) => CombinedSerializer::build(extras_schema, config, definitions)?, None => AnySerializer::build(schema, config, definitions)?, }; let items_serializers: Vec = items diff --git a/src/serializers/type_serializers/typed_dict.rs b/src/serializers/type_serializers/typed_dict.rs index 4d0cf4a3c..5967738ae 100644 --- a/src/serializers/type_serializers/typed_dict.rs +++ b/src/serializers/type_serializers/typed_dict.rs @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyString}; use ahash::AHashMap; +use crate::build_tools::py_schema_err; use crate::build_tools::{py_schema_error_type, schema_or_config, ExtraBehavior}; use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; @@ -34,6 +35,14 @@ impl BuildSerializer for TypedDictBuilder { let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?; let mut fields: AHashMap = AHashMap::with_capacity(fields_dict.len()); + let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) { + (Some(v), FieldsMode::TypedDictAllow) => { + Some(CombinedSerializer::build(v.extract()?, config, definitions)?) 
+ } + (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), + (_, _) => None, + }; + for (key, value) in fields_dict { let key_py: &PyString = key.downcast()?; let key: String = key_py.extract()?; @@ -56,6 +65,6 @@ impl BuildSerializer for TypedDictBuilder { let computed_fields = ComputedFields::new(schema, config, definitions)?; - Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into()) + Ok(GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields).into()) } } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index b5f8dc6ba..7b7be282e 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -38,6 +38,7 @@ pub struct DataclassArgsValidator { dataclass_name: String, validator_name: String, extra_behavior: ExtraBehavior, + extras_validator: Option>, loc_by_alias: bool, } @@ -55,6 +56,12 @@ impl BuildValidator for DataclassArgsValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; + let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { + (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)), + (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), + (_, _) => None, + }; + let fields_schema: &PyList = schema.get_as_req(intern!(py, "fields"))?; let mut fields: Vec = Vec::with_capacity(fields_schema.len()); @@ -118,6 +125,7 @@ impl BuildValidator for DataclassArgsValidator { dataclass_name, validator_name, extra_behavior, + extras_validator, loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), } .into()) @@ -267,7 +275,22 @@ impl Validator for DataclassArgsValidator { } ExtraBehavior::Ignore => {} ExtraBehavior::Allow => { - output_dict.set_item(either_str.as_py_string(py), value)? 
+ if let Some(ref validator) = self.extras_validator { + match validator.validate(py, value, state) { + Ok(value) => output_dict + .set_item(either_str.as_py_string(py), value)?, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push(err.with_outer_location( + raw_key.as_loc_item(), + )); + } + } + Err(err) => return Err(err), + } + } else { + output_dict.set_item(either_str.as_py_string(py), value)? + } } } } diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 131118cdb..29e9522f2 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -36,7 +36,7 @@ pub struct ModelFieldsValidator { fields: Vec, model_name: String, extra_behavior: ExtraBehavior, - extra_validator: Option>, + extras_validator: Option>, strict: bool, from_attributes: bool, loc_by_alias: bool, @@ -58,9 +58,9 @@ impl BuildValidator for ModelFieldsValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; - let extra_validator = match (schema.get_item(intern!(py, "extra_validator")), &extra_behavior) { + let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)), - (Some(_), _) => return py_schema_err!("extra_validator can only be used if extra_behavior=allow"), + (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, }; let model_name: String = schema @@ -102,7 +102,7 @@ impl BuildValidator for ModelFieldsValidator { fields, model_name, extra_behavior, - extra_validator, + extras_validator, strict, from_attributes, loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), @@ -113,7 +113,7 @@ impl BuildValidator for ModelFieldsValidator { impl_py_gc_traverse!(ModelFieldsValidator { fields, - extra_validator + extras_validator }); impl Validator for 
ModelFieldsValidator { @@ -265,7 +265,7 @@ impl Validator for ModelFieldsValidator { ExtraBehavior::Ignore => {} ExtraBehavior::Allow => { let py_key = either_str.as_py_string(py); - if let Some(ref validator) = self.extra_validator { + if let Some(ref validator) = self.extras_validator { match validator.validate(py, value, state) { Ok(value) => { model_extra_dict.set_item(py_key, value)?; @@ -373,7 +373,7 @@ impl Validator for ModelFieldsValidator { // For models / typed dicts we forbid assigning extra attributes // unless the user explicitly set extra_behavior to 'allow' match self.extra_behavior { - ExtraBehavior::Allow => match self.extra_validator { + ExtraBehavior::Allow => match self.extras_validator { Some(ref validator) => prepare_result( state.with_new_extra(new_extra, |state| validator.validate(py, field_value, state)), ), @@ -430,7 +430,7 @@ impl Validator for ModelFieldsValidator { self.fields .iter_mut() .try_for_each(|f| f.validator.complete(definitions))?; - match &mut self.extra_validator { + match &mut self.extras_validator { Some(v) => v.complete(definitions), None => Ok(()), } diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 1a3033d48..07887fddb 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -91,7 +91,7 @@ impl Validator for TupleVariableValidator { pub struct TuplePositionalValidator { strict: bool, items_validators: Vec, - extra_validator: Option>, + extras_validator: Option>, name: String, } @@ -117,7 +117,7 @@ impl BuildValidator for TuplePositionalValidator { Ok(Self { strict: is_strict(schema, config)?, items_validators: validators, - extra_validator: match schema.get_item(intern!(py, "extra_schema")) { + extras_validator: match schema.get_item(intern!(py, "extras_schema")) { Some(v) => Some(Box::new(build_validator(v, config, definitions)?)), None => None, }, @@ -134,7 +134,7 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, state: &mut ValidationState, output: &mut Vec, errors: &mut Vec>, 
- extra_validator: &Option>, + extras_validator: &Option>, items_validators: &[CombinedValidator], collection_iter: &mut T, collection_len: Option, @@ -160,8 +160,8 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, } for (index, result) in collection_iter.enumerate() { let item = result?; - match extra_validator { - Some(ref extra_validator) => match extra_validator.validate(py, item, state) { + match extras_validator { + Some(ref extras_validator) => match extras_validator.validate(py, item, state) { Ok(item) => output.push(item), Err(ValError::LineErrors(line_errors)) => { errors.extend( @@ -193,7 +193,7 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, impl_py_gc_traverse!(TuplePositionalValidator { items_validators, - extra_validator + extras_validator }); impl Validator for TuplePositionalValidator { @@ -218,7 +218,7 @@ impl Validator for TuplePositionalValidator { state, &mut output, &mut errors, - &self.extra_validator, + &self.extras_validator, &self.items_validators, &mut $collection_iter, collection_len, @@ -252,7 +252,7 @@ impl Validator for TuplePositionalValidator { .any(|v| v.different_strict_behavior(definitions, true)) { true - } else if let Some(ref v) = self.extra_validator { + } else if let Some(ref v) = self.extras_validator { v.different_strict_behavior(definitions, true) } else { false @@ -270,7 +270,7 @@ impl Validator for TuplePositionalValidator { self.items_validators .iter_mut() .try_for_each(|v| v.complete(definitions))?; - match &mut self.extra_validator { + match &mut self.extras_validator { Some(v) => v.complete(definitions), None => Ok(()), } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 69fb7f641..a095a52f1 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -35,7 +35,7 @@ impl_py_gc_traverse!(TypedDictField { validator }); pub struct TypedDictValidator { fields: Vec, extra_behavior: ExtraBehavior, - extra_validator: Option>, + extras_validator: Option>, 
strict: bool, loc_by_alias: bool, } @@ -61,9 +61,9 @@ impl BuildValidator for TypedDictValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; - let extra_validator = match (schema.get_item(intern!(py, "extra_validator")), &extra_behavior) { + let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)), - (Some(_), _) => return py_schema_err!("extra_validator can only be used if extra_behavior=allow"), + (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, }; @@ -129,7 +129,7 @@ impl BuildValidator for TypedDictValidator { Ok(Self { fields, extra_behavior, - extra_validator, + extras_validator, strict, loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), } @@ -139,7 +139,7 @@ impl BuildValidator for TypedDictValidator { impl_py_gc_traverse!(TypedDictValidator { fields, - extra_validator + extras_validator }); impl Validator for TypedDictValidator { @@ -261,7 +261,7 @@ impl Validator for TypedDictValidator { ExtraBehavior::Ignore => {} ExtraBehavior::Allow => { let py_key = either_str.as_py_string(py); - if let Some(ref validator) = self.extra_validator { + if let Some(ref validator) = self.extras_validator { match validator.validate(py, value, state) { Ok(value) => { output_dict.set_item(py_key, value)?; @@ -314,7 +314,7 @@ impl Validator for TypedDictValidator { self.fields .iter_mut() .try_for_each(|f| f.validator.complete(definitions))?; - match &mut self.extra_validator { + match &mut self.extras_validator { Some(v) => v.complete(definitions), None => Ok(()), } diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 4ac2d30ab..6c9de2fb2 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -875,7 
+875,7 @@ def test_tuple_many_variable(benchmark): @pytest.mark.benchmark(group='tuple-many') def test_tuple_many_positional(benchmark): - v = SchemaValidator({'type': 'tuple-positional', 'items_schema': [], 'extra_schema': {'type': 'int'}}) + v = SchemaValidator({'type': 'tuple-positional', 'items_schema': [], 'extras_schema': {'type': 'int'}}) assert v.validate_python(list(range(10))) == tuple(range(10)) benchmark(v.validate_python, list(range(10))) diff --git a/tests/serializers/test_dataclasses.py b/tests/serializers/test_dataclasses.py index 57b20c4b2..eb4bede97 100644 --- a/tests/serializers/test_dataclasses.py +++ b/tests/serializers/test_dataclasses.py @@ -6,7 +6,7 @@ import pytest -from pydantic_core import SchemaSerializer, core_schema +from pydantic_core import SchemaSerializer, SchemaValidator, core_schema on_pypy = platform.python_implementation() == 'PyPy' # pypy doesn't seem to maintain order of `__dict__` @@ -164,3 +164,31 @@ class SubModel(Model): s = SchemaSerializer(schema) assert s.to_python(dc) == {'x': 1, 'x2': 2} assert s.to_json(dc) == b'{"x":1,"x2":2}' + + +@pytest.mark.xfail(reason='dataclasses do not serialize extras') +def test_extra_custom_serializer(): + @dataclasses.dataclass + class Model: + pass + + schema = core_schema.dataclass_schema( + Model, + core_schema.dataclass_args_schema( + 'Model', + [], + extra_behavior='allow', + # extras_schema=core_schema.any_schema( + # serialization=core_schema.plain_serializer_function_ser_schema( + # lambda v: v + ' bam!', + # ) + # ) + ), + [], + ) + s = SchemaSerializer(schema) + v = SchemaValidator(schema) + + m = v.validate_python({'extra': 'extra'}) + + assert s.to_python(m) == {'extra': 'extra bam!'} diff --git a/tests/serializers/test_list_tuple.py b/tests/serializers/test_list_tuple.py index 836c7499d..9149941e1 100644 --- a/tests/serializers/test_list_tuple.py +++ b/tests/serializers/test_list_tuple.py @@ -358,7 +358,7 @@ def f(prefix, value, _info): 
serialization=core_schema.plain_serializer_function_ser_schema(partial(f, 'b'), info_arg=True) ), ], - 'extra_schema': core_schema.any_schema( + 'extras_schema': core_schema.any_schema( serialization=core_schema.plain_serializer_function_ser_schema(partial(f, 'extra'), info_arg=True) ), } @@ -396,7 +396,7 @@ def test_tuple_pos_dict_key(): s = SchemaSerializer( core_schema.dict_schema( core_schema.tuple_positional_schema( - [core_schema.int_schema(), core_schema.str_schema()], extra_schema=core_schema.int_schema() + [core_schema.int_schema(), core_schema.str_schema()], extras_schema=core_schema.int_schema() ), core_schema.int_schema(), ) diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index e13cadb70..16e213b0d 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -2,7 +2,7 @@ import json import platform from random import randint -from typing import Any, ClassVar +from typing import Any, ClassVar, Dict try: from functools import cached_property @@ -915,3 +915,27 @@ class InnerModel: s_repr = plain_repr(s) assert 'has_extra:true,root_model:false,name:"InnerModel"' in s_repr assert 'has_extra:false,root_model:false,name:"OuterModel"' in s_repr + + +def test_extra_custom_serializer(): + class Model: + __slots__ = ('__pydantic_extra__', '__dict__') + __pydantic_extra__: Dict[str, Any] + + schema = core_schema.model_schema( + Model, + core_schema.model_fields_schema( + {}, + extra_behavior='allow', + extras_schema=core_schema.any_schema( + serialization=core_schema.plain_serializer_function_ser_schema(lambda v: v + ' bam!') + ), + ), + extra_behavior='allow', + ) + s = SchemaSerializer(schema) + + m = Model() + m.__pydantic_extra__ = {'extra': 'extra'} + + assert s.to_python(m) == {'extra': 'extra bam!'} diff --git a/tests/serializers/test_typed_dict.py b/tests/serializers/test_typed_dict.py index 9bc3c64a9..df507a248 100644 --- a/tests/serializers/test_typed_dict.py +++ 
b/tests/serializers/test_typed_dict.py @@ -318,3 +318,18 @@ def ser_x(data: Model, v: Any, serializer: core_schema.SerializerFunctionWrapHan ) ) assert json.loads(s.to_json(Model(x=1000))) == {'x': '1_000'} + + +def test_extra_custom_serializer(): + schema = core_schema.typed_dict_schema( + {}, + extra_behavior='allow', + extras_schema=core_schema.any_schema( + serialization=core_schema.plain_serializer_function_ser_schema(lambda v: v + ' bam!') + ), + ) + s = SchemaSerializer(schema) + + m = {'extra': 'extra'} + + assert s.to_python(m) == {'extra': 'extra bam!'} diff --git a/tests/test_json.py b/tests/test_json.py index c29c3ba67..272fb5dcb 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -142,7 +142,7 @@ def test_error_loc(): 'fields': { 'field_a': {'type': 'typed-dict-field', 'schema': {'type': 'list', 'items_schema': {'type': 'int'}}} }, - 'extra_validator': {'type': 'int'}, + 'extras_schema': {'type': 'int'}, 'extra_behavior': 'allow', } ) diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index 61d0c779b..dbba463e2 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -210,9 +210,9 @@ def test_forbid_extra(): def test_allow_extra_invalid(): - with pytest.raises(SchemaError, match='extra_validator can only be used if extra_behavior=allow'): + with pytest.raises(SchemaError, match='extras_schema can only be used if extra_behavior=allow'): SchemaValidator( - {'type': 'model-fields', 'fields': {}, 'extra_validator': {'type': 'int'}, 'extra_behavior': 'ignore'} + {'type': 'model-fields', 'fields': {}, 'extras_schema': {'type': 'int'}, 'extra_behavior': 'ignore'} ) @@ -358,7 +358,7 @@ def test_validate_assignment_allow_extra_validate(): { 'type': 'model-fields', 'fields': {'field_a': {'type': 'model-field', 'schema': {'type': 'str'}}}, - 'extra_validator': {'type': 'int'}, + 'extras_schema': {'type': 'int'}, 'extra_behavior': 'allow', } ) @@ -1659,19 +1659,19 @@ def 
test_frozen_field(): ], ) @pytest.mark.parametrize( - 'extra_validator_kw, expected_extra_value', - [({}, '123'), ({'extra_validator': None}, '123'), ({'extra_validator': core_schema.int_schema()}, 123)], - ids=['extra_validator=unset', 'extra_validator=None', 'extra_validator=int'], + 'extras_schema_kw, expected_extra_value', + [({}, '123'), ({'extras_schema': None}, '123'), ({'extras_schema': core_schema.int_schema()}, 123)], + ids=['extras_schema=unset', 'extras_schema=None', 'extras_schema=int'], ) def test_extra_behavior_allow( config: Union[core_schema.CoreConfig, None], schema_extra_behavior_kw: Dict[str, Any], - extra_validator_kw: Dict[str, Any], + extras_schema_kw: Dict[str, Any], expected_extra_value: Any, ): v = SchemaValidator( core_schema.model_fields_schema( - {'f': core_schema.model_field(core_schema.str_schema())}, **schema_extra_behavior_kw, **extra_validator_kw + {'f': core_schema.model_field(core_schema.str_schema())}, **schema_extra_behavior_kw, **extras_schema_kw ), config=config, ) diff --git a/tests/validators/test_tuple.py b/tests/validators/test_tuple.py index 9e365050f..b7f3e9720 100644 --- a/tests/validators/test_tuple.py +++ b/tests/validators/test_tuple.py @@ -264,7 +264,7 @@ def test_positional_empty(py_and_json: PyAndJson): def test_positional_empty_extra(py_and_json: PyAndJson): - v = py_and_json({'type': 'tuple-positional', 'items_schema': [], 'extra_schema': {'type': 'int'}}) + v = py_and_json({'type': 'tuple-positional', 'items_schema': [], 'extras_schema': {'type': 'int'}}) assert v.validate_test([]) == () assert v.validate_python(()) == () assert v.validate_test([1]) == (1,) @@ -408,7 +408,7 @@ def test_tuple_fix_extra(input_value, expected, cache): { 'type': 'tuple-positional', 'items_schema': [{'type': 'int'}, {'type': 'str'}], - 'extra_schema': {'type': 'str'}, + 'extras_schema': {'type': 'str'}, } ) @@ -422,7 +422,7 @@ def test_tuple_fix_extra(input_value, expected, cache): def test_tuple_fix_extra_any(): v = 
SchemaValidator( - {'type': 'tuple-positional', 'items_schema': [{'type': 'str'}], 'extra_schema': {'type': 'any'}} + {'type': 'tuple-positional', 'items_schema': [{'type': 'str'}], 'extras_schema': {'type': 'any'}} ) assert v.validate_python([b'1']) == ('1',) assert v.validate_python([b'1', 2]) == ('1', 2) diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index b073ad88d..1d3b694f1 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -186,9 +186,9 @@ def test_forbid_extra(): def test_allow_extra_invalid(): - with pytest.raises(SchemaError, match='extra_validator can only be used if extra_behavior=allow'): + with pytest.raises(SchemaError, match='extras_schema can only be used if extra_behavior=allow'): SchemaValidator( - {'type': 'typed-dict', 'fields': {}, 'extra_validator': {'type': 'int'}, 'extra_behavior': 'ignore'} + {'type': 'typed-dict', 'fields': {}, 'extras_schema': {'type': 'int'}, 'extra_behavior': 'ignore'} ) @@ -1089,21 +1089,21 @@ def wrap_function(input_value, validator, info): ], ) @pytest.mark.parametrize( - 'extra_validator_kw, expected_extra_value', - [({}, '123'), ({'extra_validator': None}, '123'), ({'extra_validator': core_schema.int_schema()}, 123)], - ids=['extra_validator=unset', 'extra_validator=None', 'extra_validator=int'], + 'extras_schema_kw, expected_extra_value', + [({}, '123'), ({'extras_schema': None}, '123'), ({'extras_schema': core_schema.int_schema()}, 123)], + ids=['extras_schema=unset', 'extras_schema=None', 'extras_schema=int'], ) def test_extra_behavior_allow( config: Union[core_schema.CoreConfig, None], schema_extra_behavior_kw: Dict[str, Any], - extra_validator_kw: Dict[str, Any], + extras_schema_kw: Dict[str, Any], expected_extra_value: Any, ): v = SchemaValidator( core_schema.typed_dict_schema( {'f': core_schema.typed_dict_field(core_schema.str_schema())}, **schema_extra_behavior_kw, - **extra_validator_kw, + **extras_schema_kw, config=config, 
) ) @@ -1173,7 +1173,7 @@ def validate(v, info): schema = core_schema.general_plain_validator_function(validate) schema = core_schema.typed_dict_schema( - {'f': core_schema.typed_dict_field(schema)}, extra_behavior='allow', extra_validator=schema + {'f': core_schema.typed_dict_field(schema)}, extra_behavior='allow', extras_schema=schema ) # If any of the Rust validators don't implement traversal properly, diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 266bd0cda..e0ebd2fb3 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -189,7 +189,7 @@ def test_tuple_positional_omit(): { 'type': 'tuple-positional', 'items_schema': [{'type': 'int'}, {'type': 'int'}], - 'extra_schema': {'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'omit'}, + 'extras_schema': {'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'omit'}, } ) assert v.validate_python((1, '2')) == (1, 2) From 882b57fc891d0800a5c26809cfae1c505d75249a Mon Sep 17 00:00:00 2001 From: Yohan Valencia <31091198+yvalencia91@users.noreply.github.com> Date: Wed, 23 Aug 2023 21:29:03 +0200 Subject: [PATCH 004/550] Fix max length error on conlist with type int (#902) --- src/input/return_enums.rs | 56 +++++++++++++++++++++++++---------- tests/validators/test_list.py | 10 +++++-- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index d0f3ba45d..c68566b42 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -105,31 +105,54 @@ struct MaxLengthCheck<'a, INPUT> { max_length: Option, field_type: &'a str, input: &'a INPUT, + known_input_length: usize, } impl<'a, INPUT: Input<'a>> MaxLengthCheck<'a, INPUT> { - fn new(max_length: Option, field_type: &'a str, input: &'a INPUT) -> Self { + fn new(max_length: Option, field_type: &'a str, input: &'a INPUT, known_input_length: usize) -> Self { Self { current_length: 0, max_length, 
field_type, input, + known_input_length, } } fn incr(&mut self) -> ValResult<'a, ()> { - if let Some(max_length) = self.max_length { - self.current_length += 1; - if self.current_length > max_length { - return Err(ValError::new( - ErrorType::TooLong { - field_type: self.field_type.to_string(), - max_length, - actual_length: self.current_length, - context: None, - }, - self.input, - )); + match self.max_length { + Some(max_length) => { + self.current_length += 1; + if self.current_length > max_length { + let biggest_length = if self.known_input_length > self.current_length { + self.known_input_length + } else { + self.current_length + }; + return Err(ValError::new( + ErrorType::TooLong { + field_type: self.field_type.to_string(), + max_length, + actual_length: biggest_length, + context: None, + }, + self.input, + )); + } + } + None => { + self.current_length += 1; + if self.current_length > self.known_input_length { + return Err(ValError::new( + ErrorType::TooLong { + field_type: self.field_type.to_string(), + max_length: self.known_input_length, + actual_length: self.current_length, + context: None, + }, + self.input, + )); + } } } Ok(()) @@ -315,7 +338,7 @@ impl<'a> GenericIterable<'a> { let capacity = self .generic_len() .unwrap_or_else(|| max_length.unwrap_or(DEFAULT_CAPACITY)); - let max_length_check = MaxLengthCheck::new(max_length, field_type, input); + let max_length_check = MaxLengthCheck::new(max_length, field_type, input, capacity); macro_rules! 
validate { ($iter:expr) => { @@ -371,7 +394,10 @@ impl<'a> GenericIterable<'a> { field_type: &'static str, max_length: Option, ) -> ValResult<'a, Vec> { - let max_length_check = MaxLengthCheck::new(max_length, field_type, input); + let capacity = self + .generic_len() + .unwrap_or_else(|| max_length.unwrap_or(DEFAULT_CAPACITY)); + let max_length_check = MaxLengthCheck::new(max_length, field_type, input, capacity); match self { GenericIterable::List(collection) => { diff --git a/tests/validators/test_list.py b/tests/validators/test_list.py index 4d56a725d..d3cfaa528 100644 --- a/tests/validators/test_list.py +++ b/tests/validators/test_list.py @@ -162,6 +162,12 @@ def test_list_error(input_value, index): infinite_generator(), Err('List should have at most 44 items after validation, not 45 [type=too_long,'), ), + ( + {'max_length': 4, 'items_schema': {'type': 'int'}}, + [0, 1, 2, 3, 4, 5, 6, 7, 8], + Err('List should have at most 4 items after validation, not 9 [type=too_long,'), + ), + ({}, infinite_generator(), Err('List should have at most 10 items after validation, not 11 [type=too_long,')), ], ) def test_list_length_constraints(kwargs: Dict[str, Any], input_value, expected): @@ -391,9 +397,9 @@ def f(v: int) -> int: { 'type': 'too_long', 'loc': (), - 'msg': 'List should have at most 10 items after validation, not 11', + 'msg': 'List should have at most 10 items after validation, not 15', 'input': data, - 'ctx': {'field_type': 'List', 'max_length': 10, 'actual_length': 11}, + 'ctx': {'field_type': 'List', 'max_length': 10, 'actual_length': 15}, } ) From 34fbd845c120869c732d84b9cc5e2797be18bb92 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 24 Aug 2023 22:30:45 +0100 Subject: [PATCH 005/550] snipe off some unsafe code (#922) --- src/errors/validation_exception.rs | 40 +++++---- src/input/input_python.rs | 24 +++--- src/input/return_enums.rs | 134 ++++++++++------------------- src/serializers/infer.rs | 23 
+++-- src/serializers/ob_type.rs | 57 ++++++------ 5 files changed, 123 insertions(+), 155 deletions(-) diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index daec6b910..8963c1d81 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -3,11 +3,10 @@ use std::fmt::{Display, Write}; use std::str::from_utf8; use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError}; -use pyo3::ffi::Py_ssize_t; +use pyo3::intern; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; -use pyo3::{ffi, intern}; use serde::ser::{Error, SerializeMap, SerializeSeq}; use serde::{Serialize, Serializer}; @@ -173,22 +172,27 @@ impl ValidationError { #[pyo3(signature = (*, include_url = true, include_context = true))] pub fn errors(&self, py: Python, include_url: bool, include_context: bool) -> PyResult> { let url_prefix = get_url_prefix(py, include_url); - // taken approximately from the pyo3, but modified to return the error during iteration - // https://github.com/PyO3/pyo3/blob/a3edbf4fcd595f0e234c87d4705eb600a9779130/src/types/list.rs#L27-L55 - unsafe { - let ptr = ffi::PyList_New(self.line_errors.len() as Py_ssize_t); - - // We create the `Py` pointer here for two reasons: - // - panics if the ptr is null - // - its Drop cleans up the list if user code or the asserts panic. - let list: Py = Py::from_owned_ptr(py, ptr); - - for (index, line_error) in (0_isize..).zip(&self.line_errors) { - let item = line_error.as_dict(py, url_prefix, include_context, &self.error_mode)?; - ffi::PyList_SET_ITEM(ptr, index, item.into_ptr()); - } - - Ok(list) + let mut iteration_error = None; + let list = PyList::new( + py, + // PyList::new takes ExactSizeIterator, so if an error occurs during iteration we + // fill the list with None before returning the error; the list will then be thrown + // away safely. 
+ self.line_errors.iter().map(|e| -> PyObject { + if iteration_error.is_some() { + return py.None(); + } + e.as_dict(py, url_prefix, include_context, &self.error_mode) + .unwrap_or_else(|err| { + iteration_error = Some(err); + py.None() + }) + }), + ); + if let Some(err) = iteration_error { + Err(err) + } else { + Ok(list.into()) } } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 5ebe9025a..b31ff4361 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -187,6 +187,8 @@ impl<'a> Input<'a> for PyAny { let str = py_str.to_str()?; serde_json::from_str(str).map_err(|e| map_json_err(self, e)) } else if let Ok(py_byte_array) = self.downcast::() { + // Safety: from_slice does not run arbitrary Python code and the GIL is held so the + // bytes array will not be mutated while from_slice is reading it serde_json::from_slice(unsafe { py_byte_array.as_bytes() }).map_err(|e| map_json_err(self, e)) } else { Err(ValError::new(ErrorTypeDefaults::JsonType, self)) @@ -235,13 +237,15 @@ impl<'a> Input<'a> for PyAny { }; Ok(str.into()) } else if let Ok(py_byte_array) = self.downcast::() { - // see https://docs.rs/pyo3/latest/pyo3/types/struct.PyByteArray.html#method.as_bytes - // for why this is marked unsafe - let str = match from_utf8(unsafe { py_byte_array.as_bytes() }) { - Ok(s) => s, + // Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated, + // and we immediately copy the bytes into a new Python string + let s = match from_utf8(unsafe { py_byte_array.as_bytes() }) { + // Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the + // final output needs to be Python anyway. 
+ Ok(s) => PyString::new(self.py(), s), Err(_) => return Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), }; - Ok(str.into()) + Ok(s.into()) } else { Err(ValError::new(ErrorTypeDefaults::StringType, self)) } @@ -337,9 +341,8 @@ impl<'a> Input<'a> for PyAny { } } fn strict_float(&'a self) -> ValResult> { - if PyFloat::is_exact_type_of(self) { - // Safety: self is PyFloat - Ok(EitherFloat::Py(unsafe { self.downcast_unchecked::() })) + if let Ok(py_float) = self.downcast_exact::() { + Ok(EitherFloat::Py(py_float)) } else if let Ok(float) = self.extract::() { // bools are cast to floats as either 0.0 or 1.0, so check for bool type in this specific case if (float == 0.0 || float == 1.0) && PyBool::is_exact_type_of(self) { @@ -353,9 +356,8 @@ impl<'a> Input<'a> for PyAny { } fn lax_float(&'a self) -> ValResult> { - if PyFloat::is_exact_type_of(self) { - // Safety: self is PyFloat - Ok(EitherFloat::Py(unsafe { self.downcast_unchecked::() })) + if let Ok(py_float) = self.downcast_exact() { + Ok(EitherFloat::Py(py_float)) } else if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? { str_as_float(self, &cow_str) } else if let Ok(float) = self.extract::() { diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index c68566b42..f1cfb8543 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -224,12 +224,11 @@ impl BuildSet for &PySet { impl BuildSet for &PyFrozenSet { fn build_add(&self, item: PyObject) -> PyResult<()> { - unsafe { - py_error_on_minusone( - self.py(), - ffi::PySet_Add(self.as_ptr(), item.to_object(self.py()).as_ptr()), - ) - } + py_error_on_minusone(self.py(), unsafe { + // Safety: self.as_ptr() the _only_ pointer to the `frozenset`, and it's allowed + // to mutate this via the C API when nothing else can refer to it. 
+ ffi::PySet_Add(self.as_ptr(), item.to_object(self.py()).as_ptr()) + }) } fn build_len(&self) -> usize { @@ -492,57 +491,32 @@ impl<'py> Iterator for MappingGenericIterator<'py> { type Item = ValResult<'py, (&'py PyAny, &'py PyAny)>; fn next(&mut self) -> Option { - let item = match self.iter.next() { - Some(Err(e)) => return Some(Err(mapping_err(e, self.iter.py(), self.input))), - Some(Ok(item)) => item, - None => return None, - }; - let tuple: &PyTuple = match item.downcast() { - Ok(tuple) => tuple, - Err(_) => { - return Some(Err(ValError::new( + Some(match self.iter.next()? { + Ok(item) => item.extract().map_err(|_| { + ValError::new( ErrorType::MappingType { error: MAPPING_TUPLE_ERROR.into(), context: None, }, self.input, - ))) - } - }; - if tuple.len() != 2 { - return Some(Err(ValError::new( - ErrorType::MappingType { - error: MAPPING_TUPLE_ERROR.into(), - context: None, - }, - self.input, - ))); - }; - #[cfg(PyPy)] - let key = tuple.get_item(0).unwrap(); - #[cfg(PyPy)] - let value = tuple.get_item(1).unwrap(); - #[cfg(not(PyPy))] - let key = unsafe { tuple.get_item_unchecked(0) }; - #[cfg(not(PyPy))] - let value = unsafe { tuple.get_item_unchecked(1) }; - Some(Ok((key, value))) + ) + }), + Err(e) => Err(mapping_err(e, self.iter.py(), self.input)), + }) } - // size_hint is omitted as it isn't needed } pub struct AttributesGenericIterator<'py> { object: &'py PyAny, - attributes: &'py PyList, - index: usize, + // PyO3 should export this type upstream + attributes_iterator: <&'py PyList as IntoIterator>::IntoIter, } impl<'py> AttributesGenericIterator<'py> { pub fn new(py_any: &'py PyAny) -> ValResult<'py, Self> { Ok(Self { object: py_any, - attributes: py_any.dir(), - index: 0, + attributes_iterator: py_any.dir().into_iter(), }) } } @@ -553,37 +527,31 @@ impl<'py> Iterator for AttributesGenericIterator<'py> { fn next(&mut self) -> Option { // loop until we find an attribute who's name does not start with underscore, // or we get to the end of the list of 
attributes - while self.index < self.attributes.len() { - #[cfg(PyPy)] - let name: &PyAny = self.attributes.get_item(self.index).unwrap(); - #[cfg(not(PyPy))] - let name: &PyAny = unsafe { self.attributes.get_item_unchecked(self.index) }; - self.index += 1; - // from benchmarks this is 14x faster than using the python `startswith` method - let name_cow = match name.downcast::() { - Ok(name) => name.to_string_lossy(), - Err(e) => return Some(Err(e.into())), - }; - if !name_cow.as_ref().starts_with('_') { - // getattr is most likely to fail due to an exception in a @property, skip - if let Ok(attr) = self.object.getattr(name_cow.as_ref()) { - // we don't want bound methods to be included, is there a better way to check? - // ref https://stackoverflow.com/a/18955425/949890 - let is_bound = matches!(attr.hasattr(intern!(attr.py(), "__self__")), Ok(true)); - // the PyFunction::is_type_of(attr) catches `staticmethod`, but also any other function, - // I think that's better than including static methods in the yielded attributes, - // if someone really wants fields, they can use an explicit field, or a function to modify input - #[cfg(not(PyPy))] - if !is_bound && !PyFunction::is_type_of(attr) { - return Some(Ok((name, attr))); - } - // MASSIVE HACK! 
PyFunction doesn't exist for PyPy, - // is_instance_of:: crashes with a null pointer, hence this hack, see - // https://github.com/pydantic/pydantic-core/pull/161#discussion_r917257635 - #[cfg(PyPy)] - if !is_bound && attr.get_type().to_string() != "" { - return Some(Ok((name, attr))); - } + let name = self.attributes_iterator.next()?; + // from benchmarks this is 14x faster than using the python `startswith` method + let name_cow = match name.downcast::() { + Ok(name) => name.to_string_lossy(), + Err(e) => return Some(Err(e.into())), + }; + if !name_cow.as_ref().starts_with('_') { + // getattr is most likely to fail due to an exception in a @property, skip + if let Ok(attr) = self.object.getattr(name_cow.as_ref()) { + // we don't want bound methods to be included, is there a better way to check? + // ref https://stackoverflow.com/a/18955425/949890 + let is_bound = matches!(attr.hasattr(intern!(attr.py(), "__self__")), Ok(true)); + // the PyFunction::is_type_of(attr) catches `staticmethod`, but also any other function, + // I think that's better than including static methods in the yielded attributes, + // if someone really wants fields, they can use an explicit field, or a function to modify input + #[cfg(not(PyPy))] + if !is_bound && !PyFunction::is_type_of(attr) { + return Some(Ok((name, attr))); + } + // MASSIVE HACK! 
PyFunction doesn't exist for PyPy, + // is_instance_of:: crashes with a null pointer, hence this hack, see + // https://github.com/pydantic/pydantic-core/pull/161#discussion_r917257635 + #[cfg(PyPy)] + if !is_bound && attr.get_type().to_string() != "" { + return Some(Ok((name, attr))); } } } @@ -621,12 +589,7 @@ pub enum GenericIterator { impl From for GenericIterator { fn from(array: JsonArray) -> Self { - let length = array.len(); - let json_iter = GenericJsonIterator { - array, - length, - index: 0, - }; + let json_iter = GenericJsonIterator { array, index: 0 }; Self::JsonArray(json_iter) } } @@ -674,14 +637,15 @@ impl GenericPyIterator { #[derive(Debug, Clone)] pub struct GenericJsonIterator { array: JsonArray, - length: usize, index: usize, } impl GenericJsonIterator { pub fn next(&mut self, _py: Python) -> PyResult> { - if self.index < self.length { - let next = unsafe { self.array.get_unchecked(self.index) }; + if self.index < self.array.len() { + // panic here is impossible due to bounds check above; compiler should be + // able to optimize it away even + let next = &self.array[self.index]; let a = (next, self.index); self.index += 1; Ok(Some(a)) @@ -940,13 +904,7 @@ impl<'a> EitherFloat<'a> { pub fn as_f64(self) -> f64 { match self { EitherFloat::F64(f) => f, - - EitherFloat::Py(f) => { - { - // Safety: known to be a python float - unsafe { ffi::PyFloat_AS_DOUBLE(f.as_ptr()) } - } - } + EitherFloat::Py(f) => f.value(), } } } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 91418bb6e..265967c57 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -133,8 +133,8 @@ pub(crate) fn infer_to_python_known( .map(|s| s.into_py(py))?, ObType::Bytearray => { let py_byte_array: &PyByteArray = value.downcast()?; - // see https://docs.rs/pyo3/latest/pyo3/types/struct.PyByteArray.html#method.as_bytes - // for why this is marked unsafe + // Safety: the GIL is held while bytes_to_string is running; it doesn't run + // arbitrary 
Python code, so py_byte_array cannot be mutated. let bytes = unsafe { py_byte_array.as_bytes() }; extra .config @@ -428,8 +428,12 @@ pub(crate) fn infer_serialize_known( } ObType::Bytearray => { let py_byte_array: &PyByteArray = value.downcast().map_err(py_err_se_err)?; - let bytes = unsafe { py_byte_array.as_bytes() }; - extra.config.bytes_mode.serialize_bytes(bytes, serializer) + // Safety: the GIL is held while serialize_bytes is running; it doesn't run + // arbitrary Python code, so py_byte_array cannot be mutated. + extra + .config + .bytes_mode + .serialize_bytes(unsafe { py_byte_array.as_bytes() }, serializer) } ObType::Dict => serialize_dict!(value.downcast::().map_err(py_err_se_err)?), ObType::List => serialize_seq_filter!(PyList), @@ -581,8 +585,15 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: ObType, key: &'py PyAny, extra: .bytes_to_string(key.py(), key.downcast::()?.as_bytes()), ObType::Bytearray => { let py_byte_array: &PyByteArray = key.downcast()?; - let bytes = unsafe { py_byte_array.as_bytes() }; - extra.config.bytes_mode.bytes_to_string(key.py(), bytes) + // Safety: the GIL is held while serialize_bytes is running; it doesn't run + // arbitrary Python code, so py_byte_array cannot be mutated during the call. 
+ // + // We copy the bytes into a new buffer immediately afterwards + extra + .config + .bytes_mode + .bytes_to_string(key.py(), unsafe { py_byte_array.as_bytes() }) + .map(|cow| Cow::Owned(cow.into_owned())) } ObType::Datetime => { let py_dt: &PyDateTime = key.downcast()?; diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index 70afcae13..fc491f618 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -1,9 +1,8 @@ -use pyo3::ffi::PyTypeObject; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{ PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList, - PySet, PyString, PyTime, PyTuple, + PySet, PyString, PyTime, PyTuple, PyType, }; use pyo3::{intern, AsPyPointer, PyTypeInfo}; @@ -100,7 +99,7 @@ impl ObTypeLookup { } pub fn is_type(&self, value: &PyAny, expected_ob_type: ObType) -> IsType { - match self.ob_type_is_expected(Some(value), value.get_type_ptr(), &expected_ob_type) { + match self.ob_type_is_expected(Some(value), value.get_type(), &expected_ob_type) { IsType::False => { if expected_ob_type == self.fallback_isinstance(value) { IsType::Subclass @@ -112,12 +111,8 @@ impl ObTypeLookup { } } - fn ob_type_is_expected( - &self, - op_value: Option<&PyAny>, - type_ptr: *mut PyTypeObject, - expected_ob_type: &ObType, - ) -> IsType { + fn ob_type_is_expected(&self, op_value: Option<&PyAny>, py_type: &PyType, expected_ob_type: &ObType) -> IsType { + let type_ptr = py_type.as_ptr(); let ob_type = type_ptr as usize; let ans = match expected_ob_type { ObType::None => self.none == ob_type, @@ -168,28 +163,26 @@ impl ObTypeLookup { // this allows for subtypes of the supported class types, // if we didn't successfully confirm the type, we try again with the next base type pointer provided // it's not null - let base_type_ptr = unsafe { (*type_ptr).tp_base }; - if base_type_ptr.is_null() { - IsType::False - } else { - // as bellow, we don't want to 
tests for dataclass etc. again, so we pass None as op_value - match self.ob_type_is_expected(None, base_type_ptr, expected_ob_type) { + match get_base_type(py_type) { + // as below, we don't want to tests for dataclass etc. again, so we pass None as op_value + Some(base_type) => match self.ob_type_is_expected(None, base_type, expected_ob_type) { IsType::False => IsType::False, _ => IsType::Subclass, - } + }, + None => IsType::False, } } } pub fn get_type(&self, value: &PyAny) -> ObType { - match self.lookup_by_ob_type(Some(value), value.get_type_ptr()) { + match self.lookup_by_ob_type(Some(value), value.get_type()) { ObType::Unknown => self.fallback_isinstance(value), ob_type => ob_type, } } - fn lookup_by_ob_type(&self, op_value: Option<&PyAny>, type_ptr: *mut PyTypeObject) -> ObType { - let ob_type = type_ptr as usize; + fn lookup_by_ob_type(&self, op_value: Option<&PyAny>, py_type: &PyType) -> ObType { + let ob_type = py_type.as_ptr() as usize; // this should be pretty fast, but still order is a bit important, so the most common types should come first // thus we don't follow the order of ObType if ob_type == self.none { @@ -246,7 +239,7 @@ impl ObTypeLookup { ObType::PydanticSerializable } else if is_dataclass(op_value) { ObType::Dataclass - } else if self.is_enum(op_value, type_ptr) { + } else if self.is_enum(op_value, py_type) { ObType::Enum } else if ob_type == self.generator_object.as_ptr() as usize || is_generator(op_value) { ObType::Generator @@ -255,25 +248,19 @@ impl ObTypeLookup { } else { // this allows for subtypes of the supported class types, // if `ob_type` didn't match any member of self, we try again with the next base type pointer - let base_type_ptr = unsafe { (*type_ptr).tp_base }; - if base_type_ptr.is_null() { - ObType::Unknown - } else { + match get_base_type(py_type) { // we don't want to tests for dataclass etc. 
again, so we pass None as op_value - self.lookup_by_ob_type(None, base_type_ptr) + Some(base_type) => self.lookup_by_ob_type(None, base_type), + None => ObType::Unknown, } } } - fn is_enum(&self, op_value: Option<&PyAny>, type_ptr: *mut PyTypeObject) -> bool { + fn is_enum(&self, op_value: Option<&PyAny>, py_type: &PyType) -> bool { // only test on the type itself, not base types if op_value.is_some() { - // see https://github.com/PyO3/pyo3/issues/2905 for details - #[cfg(all(PyPy, not(Py_3_9)))] - let meta_type = unsafe { (*type_ptr).ob_type }; - #[cfg(any(not(PyPy), Py_3_9))] - let meta_type = unsafe { (*type_ptr).ob_base.ob_base.ob_type }; - meta_type as usize == self.enum_object.as_ptr() as usize + let meta_type = py_type.get_type(); + meta_type.is(&self.enum_object) } else { false } @@ -434,3 +421,9 @@ impl PartialEq for ObType { } } } + +fn get_base_type(py_type: &PyType) -> Option<&PyType> { + let base_type_ptr = unsafe { (*py_type.as_type_ptr()).tp_base }; + // Safety: `base_type_ptr` must be a valid pointer to a Python type object, or null. 
+ unsafe { py_type.py().from_borrowed_ptr_or_opt(base_type_ptr.cast()) } +} From b9f45f9debf1c7f5e49dd09af63f1b3c4f54475b Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 28 Aug 2023 12:03:16 -0600 Subject: [PATCH 006/550] Make round trip exclude computed fields (#934) --- src/serializers/computed_fields.rs | 8 ++++++++ tests/serializers/test_model.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 572a35dc4..8a1f041ae 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -48,6 +48,10 @@ impl ComputedFields { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult<()> { + if extra.round_trip { + // Do not serialize computed fields + return Ok(()); + } for computed_fields in &self.0 { computed_fields.to_python(model, output_dict, filter, include, exclude, extra)?; } @@ -63,6 +67,10 @@ impl ComputedFields { exclude: Option<&PyAny>, extra: &Extra, ) -> Result<(), S::Error> { + if extra.round_trip { + // Do not serialize computed fields + return Ok(()); + } for computed_field in &self.0 { let property_name_py = computed_field.property_name_py.as_ref(model.py()); if let Some((next_include, next_exclude)) = filter diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index 16e213b0d..32ecd3c1a 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -570,6 +570,9 @@ def area(self) -> bytes: assert s.to_python(Model(width=3, height=4), mode='json') == {'width': 3, 'height': 4, 'area': '12'} assert s.to_json(Model(width=3, height=4)) == b'{"width":3,"height":4,"area":"12"}' + assert s.to_python(Model(width=3, height=4), round_trip=True) == {'width': 3, 'height': 4} + assert s.to_json(Model(width=3, height=4), round_trip=True) == b'{"width":3,"height":4}' + def test_property_alias(): @dataclasses.dataclass From 
a23db506c16282ac7a5c0b019a7ebf413463ad15 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 29 Aug 2023 01:04:20 -0500 Subject: [PATCH 007/550] Update version to 2.7.0 (#936) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 25119d19c..d45b03f51 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,7 +242,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.6.3" +version = "2.7.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index b16624acc..ce32d0cc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.6.3" +version = "2.7.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 42d48fefd8a14e1553c048b8a145eeff46f98a12 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 08:29:47 +0100 Subject: [PATCH 008/550] Bump base64 from 0.21.2 to 0.21.3 (#933) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d45b03f51..d26e8006c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,9 +31,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" [[package]] name = "bitflags" diff --git a/Cargo.toml b/Cargo.toml index ce32d0cc5..9d020750b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ ahash = "0.8.0" url = "2.3.1" # idna is 
already required by url, added here to be explicit idna = "0.4.0" -base64 = "0.21.2" +base64 = "0.21.3" num-bigint = "0.4.3" python3-dll-a = "0.2.7" uuid = "1.4.1" From 39b4a7815fac155a9d64a9bcb80fb7149ec5b98a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 08:30:36 +0100 Subject: [PATCH 009/550] Bump url from 2.4.0 to 2.4.1 (#932) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d26e8006c..36aa60b42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -542,9 +542,9 @@ checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" [[package]] name = "url" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" dependencies = [ "form_urlencoded", "idna", diff --git a/Cargo.toml b/Cargo.toml index 9d020750b..f084450dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ serde = { version = "1.0.185", features = ["derive"] } speedate = "0.12.0" smallvec = "1.11.0" ahash = "0.8.0" -url = "2.3.1" +url = "2.4.1" # idna is already required by url, added here to be explicit idna = "0.4.0" base64 = "0.21.3" From a52d38b0a054f847c736a977658f285910e53253 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 08:33:03 +0100 Subject: [PATCH 010/550] Bump regex from 1.9.3 to 1.9.4 (#930) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 12 ++++++------ Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.lock 
b/Cargo.lock index 36aa60b42..0845c7116 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -355,9 +355,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.3" +version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" +checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" dependencies = [ "aho-corasick", "memchr", @@ -367,9 +367,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" +checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" dependencies = [ "aho-corasick", "memchr", @@ -378,9 +378,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "rustversion" diff --git a/Cargo.toml b/Cargo.toml index f084450dd..296258def 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ [dependencies] pyo3 = { version = "0.19.2", features = ["generate-import-lib", "num-bigint"] } -regex = "1.9.3" +regex = "1.9.4" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.2" serde_json = {version = "1.0.105", features = ["arbitrary_precision", "preserve_order"]} From 2dfcd6b8be45a71a7584c87742b1288fe123255d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 08:33:28 +0100 Subject: [PATCH 011/550] Bump ruff from 0.0.285 to 0.0.286 (#928) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- 
tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 94c5c0303..d575ef749 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.7.0 griffe==0.34.0 pyright==1.1.323 -ruff==0.0.285 +ruff==0.0.286 mypy==1.5.1 From cf90503071226c7975bea448eceb34745da40203 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 08:34:02 +0100 Subject: [PATCH 012/550] Bump griffe from 0.34.0 to 0.35.2 (#927) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index d575ef749..5a158d339 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.7.0 -griffe==0.34.0 +griffe==0.35.2 pyright==1.1.323 ruff==0.0.286 mypy==1.5.1 From 53385d3657e5e16ffcb5d7b3e1457bd86d1c8a57 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 07:40:54 +0000 Subject: [PATCH 013/550] Bump pyright from 1.1.323 to 1.1.324 (#926) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 5a158d339..fc4c55b3c 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.7.0 griffe==0.35.2 -pyright==1.1.323 +pyright==1.1.324 ruff==0.0.286 mypy==1.5.1 From 18425b00d5961840d78434bf2fdc6a9e5f4490e4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" 
<49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 07:41:21 +0000 Subject: [PATCH 014/550] Bump num-bigint from 0.4.3 to 0.4.4 (#931) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0845c7116..5cb3573b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -168,9 +168,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" dependencies = [ "autocfg", "num-integer", @@ -189,9 +189,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" dependencies = [ "autocfg", ] diff --git a/Cargo.toml b/Cargo.toml index 296258def..e7a533605 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,7 @@ url = "2.4.1" # idna is already required by url, added here to be explicit idna = "0.4.0" base64 = "0.21.3" -num-bigint = "0.4.3" +num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.4.1" From 01a1523afafd8ab5c9073633905c3792c7d06483 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 29 Aug 2023 07:41:50 +0000 Subject: [PATCH 015/550] Bump serde from 1.0.185 to 1.0.188 (#929) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git 
a/Cargo.lock b/Cargo.lock index 5cb3573b2..87a889e43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -402,18 +402,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.185" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be9b6f69f1dfd54c3b568ffa45c310d6973a5e5148fd40cf515acaf38cf5bc31" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.185" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc59dfdcbad1437773485e0367fea4b090a2e0a16d9ffc46af47764536a298ec" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index e7a533605..75b2bcb34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.2" serde_json = {version = "1.0.105", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" -serde = { version = "1.0.185", features = ["derive"] } +serde = { version = "1.0.188", features = ["derive"] } speedate = "0.12.0" smallvec = "1.11.0" ahash = "0.8.0" From 9ec1036b4583d5ada15f7ac3e32fcc04a981746a Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Thu, 31 Aug 2023 15:37:01 +0200 Subject: [PATCH 016/550] Fix typo in `MultiHostUrl.build` docstring (#938) --- python/pydantic_core/_pydantic_core.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 5aae6df8a..9b24e7076 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -631,7 +631,7 @@ class MultiHostUrl(SupportsAllComparisons): fragment: Optional[str] = None, ) 
-> Self: """ - Build a new `MultiHostUul` instance from its component parts. + Build a new `MultiHostUrl` instance from its component parts. This method takes either `hosts` - a list of `MultiHostHost` typed dicts, or the individual components `username`, `password`, `host` and `port`. From 04a9135a695212665017175c96f1bc7dc5007349 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 09:42:43 +0100 Subject: [PATCH 017/550] Bump regex from 1.9.4 to 1.9.5 (#944) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 12 ++++++------ Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 87a889e43..4e38d8cf7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -153,9 +153,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "memoffset" @@ -355,9 +355,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.4" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick", "memchr", @@ -367,9 +367,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 
75b2bcb34..3f5b25e7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ [dependencies] pyo3 = { version = "0.19.2", features = ["generate-import-lib", "num-bigint"] } -regex = "1.9.4" +regex = "1.9.5" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.2" serde_json = {version = "1.0.105", features = ["arbitrary_precision", "preserve_order"]} From fffb1d853ccdbacbf1dee84be8c76a67178ae5e1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 09:43:20 +0100 Subject: [PATCH 018/550] Bump pytest from 7.4.0 to 7.4.1 (#939) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index c1d0e8a70..c8469082f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,7 +3,7 @@ dirty-equals==0.6.0 hypothesis==6.79.4 # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux pandas==2.0.3; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' -pytest==7.4.0 +pytest==7.4.1 # we run codspeed benchmarks on x86_64 CPython (i.e. 
native github actions architecture) pytest-codspeed~=2.1.0; implementation_name == "cpython" and platform_machine == 'x86_64' pytest-examples==0.0.10 From 7dc8946e05d73ddd68e71ee34cb6a94fa470c65f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 09:43:34 +0100 Subject: [PATCH 019/550] Bump ruff from 0.0.286 to 0.0.287 (#942) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index fc4c55b3c..5536a4cb8 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.7.0 griffe==0.35.2 pyright==1.1.324 -ruff==0.0.286 +ruff==0.0.287 mypy==1.5.1 From 8a0bc71f3ed13fd93eeaa10bf4acd0c206c13fab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 09:43:53 +0100 Subject: [PATCH 020/550] Bump griffe from 0.35.2 to 0.36.0 (#940) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 5536a4cb8..1b125e5f6 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.7.0 -griffe==0.35.2 +griffe==0.36.0 pyright==1.1.324 ruff==0.0.287 mypy==1.5.1 From e09112a28aa6dae3e8a41ac3d1a4e9ebc10d5863 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 08:50:55 +0000 Subject: [PATCH 021/550] Bump pyright from 1.1.324 to 1.1.325 (#941) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] 
<49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 1b125e5f6..2204967c1 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.7.0 griffe==0.36.0 -pyright==1.1.324 +pyright==1.1.325 ruff==0.0.287 mypy==1.5.1 From 6a139753af85fc7bb6b34f26c1328994506f94ee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 10:15:42 +0100 Subject: [PATCH 022/550] Bump griffe from 0.35.2 to 0.36.1 (#946) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 2204967c1..6faa2c2d2 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.7.0 -griffe==0.36.0 +griffe==0.36.1 pyright==1.1.325 ruff==0.0.287 mypy==1.5.1 From 2d9df49bc2c99b076cfbe87b0101f440abbd8519 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 7 Sep 2023 08:16:06 -0500 Subject: [PATCH 023/550] Update pytz (#949) --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index c8469082f..aad4fd52c 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -11,6 +11,6 @@ pytest-speed==0.3.5 pytest-mock==3.11.1 pytest-pretty==1.2.0 pytest-timeout==2.1.0 -pytz==2023.3 +pytz==2023.3.post1 # numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. 
aarch64 musllinux numpy==1.25.2; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' From 6769140dfb4cbe817567344308d3d492ef0e677e Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 7 Sep 2023 08:24:08 -0500 Subject: [PATCH 024/550] Fix parsing int from large decimals (#948) Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com> --- src/input/input_abstract.rs | 12 ++++---- src/input/input_json.rs | 14 ++++----- src/input/input_python.rs | 24 ++++++++++----- src/input/shared.rs | 14 +++++++++ src/validators/decimal.rs | 58 +++++++++++++++++++++--------------- tests/validators/test_int.py | 21 +++++++++++++ 6 files changed, 97 insertions(+), 46 deletions(-) diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 80ebc059c..3e780237b 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -152,17 +152,17 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { self.strict_float() } - fn validate_decimal(&'a self, strict: bool, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> { if strict { - self.strict_decimal(decimal_type) + self.strict_decimal(py) } else { - self.lax_decimal(decimal_type) + self.lax_decimal(py) } } - fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny>; + fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny>; #[cfg_attr(has_no_coverage, no_coverage)] - fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { - self.strict_decimal(decimal_type) + fn lax_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { + self.strict_decimal(py) } fn validate_dict(&'a self, strict: bool) -> ValResult> { diff --git a/src/input/input_json.rs b/src/input/input_json.rs index c2da56703..86079da09 100644 --- a/src/input/input_json.rs 
+++ b/src/input/input_json.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyString, PyType}; +use pyo3::types::{PyDict, PyString}; use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; @@ -178,13 +178,12 @@ impl<'a> Input<'a> for JsonInput { } } - fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { - let py = decimal_type.py(); + fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { match self { - JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, decimal_type), + JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py), JsonInput::String(..) | JsonInput::Int(..) | JsonInput::Uint(..) | JsonInput::BigInt(..) => { - create_decimal(self.to_object(py).into_ref(py), self, decimal_type) + create_decimal(self.to_object(py).into_ref(py), self, py) } _ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)), } @@ -439,9 +438,8 @@ impl<'a> Input<'a> for String { str_as_float(self, self) } - fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { - let py = decimal_type.py(); - create_decimal(self.to_object(py).into_ref(py), self, decimal_type) + fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { + create_decimal(self.to_object(py).into_ref(py), self, py) } #[cfg_attr(has_no_coverage, no_coverage)] diff --git a/src/input/input_python.rs b/src/input/input_python.rs index b31ff4361..3fbf20240 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -13,7 +13,7 @@ use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; -use crate::validators::decimal::create_decimal; +use crate::validators::decimal::{create_decimal, get_decimal_type}; use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl}; use super::datetime::{ @@ -21,7 +21,7 @@ 
use super::datetime::{ float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; +use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; use super::{ py_string_str, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs, @@ -324,6 +324,10 @@ impl<'a> Input<'a> for PyAny { } else if PyInt::is_type_of(self) { // force to an int to upcast to a pure python int to maintain current behaviour EitherInt::upcast(self) + } else if PyFloat::is_exact_type_of(self) { + float_as_int(self, self.extract::()?) + } else if let Ok(decimal) = self.strict_decimal(self.py()) { + decimal_as_int(self.py(), self, decimal) } else if let Ok(float) = self.extract::() { float_as_int(self, float) } else { @@ -367,7 +371,9 @@ impl<'a> Input<'a> for PyAny { } } - fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { + let decimal_type_obj: Py = get_decimal_type(py); + let decimal_type = decimal_type_obj.as_ref(py); // Fast path for existing decimal objects if self.is_exact_instance(decimal_type) { return Ok(self); @@ -375,7 +381,7 @@ impl<'a> Input<'a> for PyAny { // Try subclasses of decimals, they will be upcast to Decimal if self.is_instance(decimal_type)? 
{ - return create_decimal(self, self, decimal_type); + return create_decimal(self, self, py); } Err(ValError::new( @@ -387,7 +393,9 @@ impl<'a> Input<'a> for PyAny { )) } - fn lax_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> { + fn lax_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { + let decimal_type_obj: Py = get_decimal_type(py); + let decimal_type = decimal_type_obj.as_ref(py); // Fast path for existing decimal objects if self.is_exact_instance(decimal_type) { return Ok(self); @@ -395,12 +403,12 @@ impl<'a> Input<'a> for PyAny { if self.is_instance_of::() || (self.is_instance_of::() && !self.is_instance_of::()) { // checking isinstance for str / int / bool is fast compared to decimal / float - create_decimal(self, self, decimal_type) + create_decimal(self, self, py) } else if self.is_instance(decimal_type)? { // upcast subclasses to decimal - return create_decimal(self, self, decimal_type); + return create_decimal(self, self, py); } else if self.is_instance_of::() { - create_decimal(self.str()?, self, decimal_type) + create_decimal(self.str()?, self, py) } else { Err(ValError::new(ErrorTypeDefaults::DecimalType, self)) } diff --git a/src/input/shared.rs b/src/input/shared.rs index bed673531..d8733bd31 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -1,4 +1,5 @@ use num_bigint::BigInt; +use pyo3::{intern, PyAny, Python}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::EitherInt; @@ -136,3 +137,16 @@ pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<'a, Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input)) } } + +pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a PyAny) -> ValResult<'a, EitherInt<'a>> { + if !decimal.call_method0(intern!(py, "is_finite"))?.extract::()? 
{ + return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); + } + let (numerator, denominator) = decimal + .call_method0(intern!(py, "as_integer_ratio"))? + .extract::<(&PyAny, &PyAny)>()?; + if denominator.extract::().map_or(true, |d| d != 1) { + return Err(ValError::new(ErrorTypeDefaults::IntFromFloat, input)); + } + Ok(EitherInt::Py(numerator)) +} diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index faacf4dd3..2564e096a 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -1,4 +1,5 @@ use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::sync::GILOnceCell; use pyo3::types::{IntoPyDict, PyDict, PyTuple, PyType}; use pyo3::{intern, AsPyPointer}; use pyo3::{prelude::*, PyTypeInfo}; @@ -13,6 +14,21 @@ use crate::tools::SchemaDict; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +static DECIMAL_TYPE: GILOnceCell> = GILOnceCell::new(); + +pub fn get_decimal_type(py: Python) -> Py { + DECIMAL_TYPE + .get_or_init(py, || { + py.import("decimal") + .and_then(|decimal_module| decimal_module.getattr("Decimal")) + .unwrap() + .extract::<&PyType>() + .unwrap() + .into() + }) + .clone() +} + #[derive(Debug, Clone)] pub struct DecimalValidator { strict: bool, @@ -25,7 +41,6 @@ pub struct DecimalValidator { gt: Option>, max_digits: Option, decimal_places: Option, - decimal_type: Py, } impl BuildValidator for DecimalValidator { @@ -55,10 +70,6 @@ impl BuildValidator for DecimalValidator { ge: schema.get_as(intern!(py, "ge"))?, gt: schema.get_as(intern!(py, "gt"))?, max_digits, - decimal_type: py - .import(intern!(py, "decimal"))? - .getattr(intern!(py, "Decimal"))? 
- .extract()?, } .into()) } @@ -69,8 +80,7 @@ impl_py_gc_traverse!(DecimalValidator { le, lt, ge, - gt, - decimal_type + gt }); impl Validator for DecimalValidator { @@ -80,11 +90,7 @@ impl Validator for DecimalValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let decimal = input.validate_decimal( - state.strict_or(self.strict), - // Safety: self and py both outlive this call - unsafe { py.from_borrowed_ptr(self.decimal_type.as_ptr()) }, - )?; + let decimal = input.validate_decimal(state.strict_or(self.strict), py)?; if !self.allow_inf_nan || self.check_digits { if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? { @@ -244,19 +250,23 @@ impl Validator for DecimalValidator { pub(crate) fn create_decimal<'a>( arg: &'a PyAny, input: &'a impl Input<'a>, - decimal_type: &'a PyType, + py: Python<'a>, ) -> ValResult<'a, &'a PyAny> { - decimal_type.call1((arg,)).map_err(|e| { - let decimal_exception = match arg - .py() - .import("decimal") - .and_then(|decimal_module| decimal_module.getattr("DecimalException")) - { - Ok(decimal_exception) => decimal_exception, - Err(e) => return ValError::InternalErr(e), - }; - handle_decimal_new_error(arg.py(), input.as_error_value(), e, decimal_exception) - }) + let decimal_type_obj: Py = get_decimal_type(py); + decimal_type_obj + .call1(py, (arg,)) + .map_err(|e| { + let decimal_exception = match arg + .py() + .import("decimal") + .and_then(|decimal_module| decimal_module.getattr("DecimalException")) + { + Ok(decimal_exception) => decimal_exception, + Err(e) => return ValError::InternalErr(e), + }; + handle_decimal_new_error(arg.py(), input.as_error_value(), e, decimal_exception) + }) + .map(|v| v.into_ref(py)) } fn handle_decimal_new_error<'a>( diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 44a806118..43cd5bacb 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -58,7 +58,9 @@ def 
test_int_py_and_json(py_and_json: PyAndJson, input_value, expected): 'input_value,expected', [ (Decimal('1'), 1), + (Decimal('1' + '0' * 1_000), int('1' + '0' * 1_000)), # a large decimal (Decimal('1.0'), 1), + (1.0, 1), (i64_max, i64_max), (str(i64_max), i64_max), (str(i64_max * 2), i64_max * 2), @@ -66,6 +68,14 @@ def test_int_py_and_json(py_and_json: PyAndJson, input_value, expected): (-i64_max + 1, -i64_max + 1), (i64_max * 2, i64_max * 2), (-i64_max * 2, -i64_max * 2), + pytest.param( + 1.00000000001, + Err( + 'Input should be a valid integer, got a number with a fractional part ' + '[type=int_from_float, input_value=1.00000000001, input_type=float]' + ), + id='decimal-remainder', + ), pytest.param( Decimal('1.001'), Err( @@ -437,3 +447,14 @@ def test_int_subclass_constraint() -> None: with pytest.raises(ValidationError, match='Input should be greater than 0'): v.validate_python(IntSubclass(0)) + + +class FloatSubclass(float): + pass + + +def test_float_subclass() -> None: + v = SchemaValidator({'type': 'int'}) + v_lax = v.validate_python(FloatSubclass(1)) + assert v_lax == 1 + assert type(v_lax) == int From f6b14cce22b1539727ed0e0a47d2c680c79c1692 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Fri, 8 Sep 2023 07:00:14 +0100 Subject: [PATCH 025/550] make error "duplicate" cheaper (#950) --- src/errors/line_error.rs | 31 +++++++++++------------ src/errors/validation_exception.rs | 16 ++++++------ src/input/input_json.rs | 7 +++--- src/input/parse_json.rs | 14 ++++++----- src/input/return_enums.rs | 9 +++---- src/lazy_index_map.rs | 40 ++++++++++++++---------------- src/validators/function.rs | 2 +- src/validators/generator.rs | 8 +++--- src/validators/json.rs | 2 +- 9 files changed, 62 insertions(+), 67 deletions(-) diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index c8230a27a..e5d3c7bac 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -62,10 +62,10 @@ impl<'a> 
ValError<'a> { } /// a bit like clone but change the lifetime to match py - pub fn duplicate<'py>(&self, py: Python<'py>) -> ValError<'py> { + pub fn into_owned(self, py: Python<'_>) -> ValError<'_> { match self { - ValError::LineErrors(errors) => errors.iter().map(|e| e.duplicate(py)).collect::>().into(), - ValError::InternalErr(err) => ValError::InternalErr(err.clone_ref(py)), + ValError::LineErrors(errors) => errors.into_iter().map(|e| e.into_owned(py)).collect::>().into(), + ValError::InternalErr(err) => ValError::InternalErr(err), ValError::Omit => ValError::Omit, ValError::UseDefault => ValError::UseDefault, } @@ -129,28 +129,26 @@ impl<'a> ValLineError<'a> { self } - /// a bit like clone but change the lifetime to match py, used by ValError.duplicate above - pub fn duplicate<'py>(&'a self, py: Python<'py>) -> ValLineError<'py> { + /// a bit like clone but change the lifetime to match py, used by ValError.into_owned above + pub fn into_owned(self, py: Python<'_>) -> ValLineError<'_> { ValLineError { - error_type: self.error_type.clone(), - input_value: InputValue::<'py>::from(self.input_value.to_object(py)), - location: self.location.clone(), + error_type: self.error_type, + input_value: match self.input_value { + InputValue::PyAny(input) => InputValue::PyAny(input.to_object(py).into_ref(py)), + InputValue::JsonInput(input) => InputValue::JsonInput(input), + InputValue::String(input) => InputValue::PyAny(input.to_object(py).into_ref(py)), + }, + location: self.location, } } } #[cfg_attr(debug_assertions, derive(Debug))] +#[derive(Clone)] pub enum InputValue<'a> { PyAny(&'a PyAny), - JsonInput(&'a JsonInput), + JsonInput(JsonInput), String(&'a str), - PyObject(PyObject), -} - -impl<'a> From for InputValue<'a> { - fn from(py_object: PyObject) -> Self { - Self::PyObject(py_object) - } } impl<'a> ToPyObject for InputValue<'a> { @@ -159,7 +157,6 @@ impl<'a> ToPyObject for InputValue<'a> { Self::PyAny(input) => input.into_py(py), Self::JsonInput(input) => 
input.to_object(py), Self::String(input) => input.into_py(py), - Self::PyObject(py_obj) => py_obj.into_py(py), } } } diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 8963c1d81..9dde6551c 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -22,7 +22,7 @@ use super::line_error::ValLineError; use super::location::Location; use super::types::{ErrorMode, ErrorType}; use super::value_exception::PydanticCustomError; -use super::ValError; +use super::{InputValue, ValError}; #[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")] #[derive(Clone)] @@ -128,11 +128,11 @@ fn get_url_prefix(py: Python, include_url: bool) -> Option<&str> { } // used to convert a validation error back to ValError for wrap functions -impl<'a> IntoPy> for ValidationError { - fn into_py(self, py: Python) -> ValError<'a> { +impl ValidationError { + pub(crate) fn into_val_error(self, py: Python<'_>) -> ValError<'_> { self.line_errors .into_iter() - .map(|e| e.into_py(py)) + .map(|e| e.into_val_line_error(py)) .collect::>() .into() } @@ -322,13 +322,13 @@ impl<'a> IntoPy for ValLineError<'a> { } } -/// opposite of above, used to extract line errors from a validation error for wrap functions -impl<'a> IntoPy> for PyLineError { - fn into_py(self, _py: Python) -> ValLineError<'a> { +impl PyLineError { + /// Used to extract line errors from a validation error for wrap functions + fn into_val_line_error(self, py: Python<'_>) -> ValLineError<'_> { ValLineError { error_type: self.error_type, location: self.location, - input_value: self.input_value.into(), + input_value: InputValue::PyAny(self.input_value.into_ref(py)), } } } diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 86079da09..d9cb81fe2 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -31,7 +31,8 @@ impl<'a> Input<'a> for JsonInput { } fn as_error_value(&'a self) -> InputValue<'a> { - InputValue::JsonInput(self) + // 
cloning JsonInput is cheap due to use of Arc + InputValue::JsonInput(self.clone()) } fn is_none(&self) -> bool { @@ -262,7 +263,7 @@ impl<'a> Input<'a> for JsonInput { JsonInput::String(s) => Ok(string_to_vec(s).into()), JsonInput::Object(object) => { // return keys iterator to match python's behavior - let keys: Vec = object.keys().map(|k| JsonInput::String(k.clone())).collect(); + let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonInput::String(k.clone())).collect()); Ok(keys.into()) } _ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)), @@ -550,5 +551,5 @@ impl<'a> Input<'a> for String { } fn string_to_vec(s: &str) -> JsonArray { - s.chars().map(|c| JsonInput::String(c.to_string())).collect() + JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect()) } diff --git a/src/input/parse_json.rs b/src/input/parse_json.rs index 3bc2d0d46..7603eaf67 100644 --- a/src/input/parse_json.rs +++ b/src/input/parse_json.rs @@ -1,9 +1,11 @@ use std::fmt; +use std::sync::Arc; use num_bigint::BigInt; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor}; +use smallvec::SmallVec; use crate::lazy_index_map::LazyIndexMap; @@ -20,8 +22,8 @@ pub enum JsonInput { Array(JsonArray), Object(JsonObject), } -pub type JsonArray = Vec; -pub type JsonObject = LazyIndexMap; +pub type JsonArray = Arc>; +pub type JsonObject = Arc>; impl ToPyObject for JsonInput { fn to_object(&self, py: Python<'_>) -> PyObject { @@ -111,13 +113,13 @@ impl<'de> Deserialize<'de> for JsonInput { where V: SeqAccess<'de>, { - let mut vec = Vec::new(); + let mut vec = SmallVec::new(); while let Some(elem) = visitor.next_element()? 
{ vec.push(elem); } - Ok(JsonInput::Array(vec)) + Ok(JsonInput::Array(JsonArray::new(vec))) } fn visit_map(self, mut visitor: V) -> Result @@ -171,9 +173,9 @@ impl<'de> Deserialize<'de> for JsonInput { while let Some((key, value)) = visitor.next_entry()? { values.insert(key, value); } - Ok(JsonInput::Object(values)) + Ok(JsonInput::Object(Arc::new(values))) } - None => Ok(JsonInput::Object(LazyIndexMap::new())), + None => Ok(JsonInput::Object(Arc::new(LazyIndexMap::new()))), } } } diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index f1cfb8543..e97fe8b81 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -625,8 +625,8 @@ impl GenericPyIterator { } } - pub fn input<'a>(&'a self, py: Python<'a>) -> &'a PyAny { - self.obj.as_ref(py) + pub fn input_as_error_value<'py>(&self, py: Python<'py>) -> InputValue<'py> { + InputValue::PyAny(self.obj.clone_ref(py).into_ref(py)) } pub fn index(&self) -> usize { @@ -654,9 +654,8 @@ impl GenericJsonIterator { } } - pub fn input<'a>(&'a self, py: Python<'a>) -> &'a PyAny { - let input = JsonInput::Array(self.array.clone()); - input.to_object(py).into_ref(py) + pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> { + InputValue::JsonInput(JsonInput::Array(self.array.clone())) } pub fn index(&self) -> usize { diff --git a/src/lazy_index_map.rs b/src/lazy_index_map.rs index 163421de3..c5621f877 100644 --- a/src/lazy_index_map.rs +++ b/src/lazy_index_map.rs @@ -1,32 +1,36 @@ use std::borrow::Borrow; -use std::cell::RefCell; use std::cmp::{Eq, PartialEq}; use std::fmt::Debug; use std::hash::Hash; use std::slice::Iter as SliceIter; +use std::sync::OnceLock; use ahash::AHashMap; +use smallvec::SmallVec; #[derive(Debug, Clone, Default)] pub struct LazyIndexMap { - vec: Vec<(K, V)>, - map: RefCell>>, + vec: SmallVec<[(K, V); 8]>, + map: OnceLock>, } /// Like [IndexMap](https://docs.rs/indexmap/latest/indexmap/) but only builds the lookup map when it's needed. 
impl LazyIndexMap where K: Clone + Debug + Eq + Hash, - V: Clone + Debug, + V: Debug, { pub fn new() -> Self { Self { - vec: Vec::new(), - map: RefCell::new(None), + vec: SmallVec::new(), + map: OnceLock::new(), } } pub fn insert(&mut self, key: K, value: V) { + if let Some(map) = self.map.get_mut() { + map.insert(key.clone(), self.vec.len()); + } self.vec.push((key, value)); } @@ -39,22 +43,14 @@ where K: Borrow + PartialEq, Q: Hash + Eq, { - let mut map = self.map.borrow_mut(); - if let Some(map) = map.as_ref() { - map.get(key).map(|&i| &self.vec[i].1) - } else { - let mut new_map = AHashMap::with_capacity(self.vec.len()); - let mut value = None; - // reverse here so the last value is the one that's returned - for (index, (k, v)) in self.vec.iter().enumerate().rev() { - if value.is_none() && k == key { - value = Some(v); - } - new_map.insert(k.clone(), index); - } - *map = Some(new_map); - value - } + let map = self.map.get_or_init(|| { + self.vec + .iter() + .enumerate() + .map(|(index, (key, _))| (key.clone(), index)) + .collect() + }); + map.get(key).map(|&i| &self.vec[i].1) } pub fn keys(&self) -> impl Iterator { diff --git a/src/validators/function.rs b/src/validators/function.rs index 8f9b25d70..fa8a0673b 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -504,7 +504,7 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> } else if let Ok(pydantic_error_type) = err.value(py).extract::() { pydantic_error_type.into_val_error(input) } else if let Ok(validation_error) = err.value(py).extract::() { - validation_error.into_py(py) + validation_error.into_val_error(py) } else { py_err_string!(err.value(py), ValueError, input) } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 698504a83..1047e31bd 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -127,14 +127,14 @@ impl ValidatorIterator { Some(validator) => { if let Some(max_length) = max_length { if index 
>= max_length { - let val_error = ValError::new( + let val_error = ValError::new_custom_input( ErrorType::TooLong { field_type: "Generator".to_string(), max_length, actual_length: index + 1, context: None, }, - $iter.input(py), + $iter.input_as_error_value(py), ); return Err(ValidationError::from_val_error( py, @@ -153,14 +153,14 @@ impl ValidatorIterator { None => { if let Some(min_length) = min_length { if $iter.index() < min_length { - let val_error = ValError::new( + let val_error = ValError::new_custom_input( ErrorType::TooShort { field_type: "Generator".to_string(), min_length, actual_length: $iter.index(), context: None, }, - $iter.input(py), + $iter.input_as_error_value(py), ); return Err(ValidationError::from_val_error( py, diff --git a/src/validators/json.rs b/src/validators/json.rs index f99ac29ae..5eda007be 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -55,7 +55,7 @@ impl Validator for JsonValidator { match self.validator { Some(ref validator) => match validator.validate(py, &json_value, state) { Ok(v) => Ok(v), - Err(err) => Err(err.duplicate(py)), + Err(err) => Err(err.into_owned(py)), }, None => Ok(json_value.to_object(py)), } From 6e9c6ebf50935236677a1c0999fafc3e949d818e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:12:53 +0200 Subject: [PATCH 026/550] Bump actions/checkout from 3 to 4 (#958) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 32 ++++++++++++++++---------------- .github/workflows/codspeed.yml | 2 +- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d784a79c..0b54ac0d8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: install rust 
nightly uses: dtolnay/rust-toolchain@nightly @@ -79,7 +79,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: install rust stable uses: dtolnay/rust-toolchain@stable @@ -116,7 +116,7 @@ jobs: runs-on: ${{ matrix.os }}-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: install rust stable uses: dtolnay/rust-toolchain@stable @@ -156,7 +156,7 @@ jobs: - 'pypy3.10' steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: set up python uses: actions/setup-python@v4 with: @@ -181,12 +181,12 @@ jobs: continue-on-error: true steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: pydantic/pydantic path: pydantic - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: path: pydantic-core @@ -221,7 +221,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: install rust stable uses: dtolnay/rust-toolchain@stable @@ -269,7 +269,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: install rust nightly uses: dtolnay/rust-toolchain@nightly @@ -288,7 +288,7 @@ jobs: build-wasm-emscripten: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: set up python id: setup-python @@ -351,7 +351,7 @@ jobs: name: build sdist runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: PyO3/maturin-action@v1 with: command: sdist @@ -432,7 +432,7 @@ jobs: runs-on: ${{ matrix.os }}-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: set up python uses: actions/setup-python@v4 @@ -485,7 +485,7 @@ jobs: runs-on: ${{ matrix.os }}-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: set up python uses: actions/setup-python@v4 @@ -559,7 +559,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: 
actions/checkout@v3 + - uses: actions/checkout@v4 - name: get dist artifacts uses: actions/download-artifact@v3 @@ -599,7 +599,7 @@ jobs: distro: alpine_latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: get dist artifacts uses: actions/download-artifact@v3 @@ -646,7 +646,7 @@ jobs: runs-on: ${{ matrix.os }}-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: set up python uses: actions/setup-python@v4 @@ -670,7 +670,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: set up python uses: actions/setup-python@v4 diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index 2150ec79b..32bc83a1e 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: From 03cecdd54fd1f3ad770944efe9991499f71042d1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:14:23 +0200 Subject: [PATCH 027/550] Bump pytest from 7.4.1 to 7.4.2 (#956) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index aad4fd52c..6ea2d16cf 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,7 +3,7 @@ dirty-equals==0.6.0 hypothesis==6.79.4 # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux pandas==2.0.3; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' -pytest==7.4.1 +pytest==7.4.2 # we run codspeed benchmarks on x86_64 CPython (i.e. 
native github actions architecture) pytest-codspeed~=2.1.0; implementation_name == "cpython" and platform_machine == 'x86_64' pytest-examples==0.0.10 From ba3c502c681951fce325a2537c59616a2f4b8f33 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:14:39 +0200 Subject: [PATCH 028/550] Bump griffe from 0.36.1 to 0.36.2 (#954) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 6faa2c2d2..8e7b5c66c 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.7.0 -griffe==0.36.1 +griffe==0.36.2 pyright==1.1.325 ruff==0.0.287 mypy==1.5.1 From 1c79f979c6826bffca5ffe5c1f0ee22ea8ea4b11 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:16:12 +0200 Subject: [PATCH 029/550] Bump serde_json from 1.0.105 to 1.0.106 (#951) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4e38d8cf7..ff04e484d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -422,9 +422,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2" dependencies = [ "indexmap", "itoa", diff --git a/Cargo.toml b/Cargo.toml index 3f5b25e7f..0ab2d81bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ pyo3 = { version = "0.19.2", features = 
["generate-import-lib", "num-bigint"] } regex = "1.9.5" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.2" -serde_json = {version = "1.0.105", features = ["arbitrary_precision", "preserve_order"]} +serde_json = {version = "1.0.106", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.188", features = ["derive"] } speedate = "0.12.0" From 5edddd5060919fa83d03ee33d6b4749e070feb83 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 17 Sep 2023 08:16:43 +0200 Subject: [PATCH 030/550] Bump base64 from 0.21.3 to 0.21.4 (#952) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ff04e484d..4c8b1d58f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,9 +31,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.3" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "bitflags" diff --git a/Cargo.toml b/Cargo.toml index 0ab2d81bc..3c7e18d4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ ahash = "0.8.0" url = "2.4.1" # idna is already required by url, added here to be explicit idna = "0.4.0" -base64 = "0.21.3" +base64 = "0.21.4" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.4.1" From 3dfecd70d1c5612742f197b0d436bcf780840f22 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:14:40 +0100 Subject: [PATCH 031/550] update to coverage_attribute feature (#966) --- build.rs | 4 +-- 
src/build_tools.rs | 2 +- src/input/input_abstract.rs | 30 ++++++++++---------- src/input/input_json.rs | 56 ++++++++++++++++++------------------- src/input/parse_json.rs | 10 +++---- src/lib.rs | 2 +- src/tools.rs | 2 +- 7 files changed, 53 insertions(+), 53 deletions(-) diff --git a/build.rs b/build.rs index d27cc3b63..7f59e1f57 100644 --- a/build.rs +++ b/build.rs @@ -32,8 +32,8 @@ fn generate_self_schema() { fn main() { pyo3_build_config::use_pyo3_cfgs(); - if let Some(true) = version_check::supports_feature("no_coverage") { - println!("cargo:rustc-cfg=has_no_coverage"); + if let Some(true) = version_check::supports_feature("coverage_attribute") { + println!("cargo:rustc-cfg=has_coverage_attribute"); } generate_self_schema(); println!("cargo:rustc-env=PROFILE={}", std::env::var("PROFILE").unwrap()); diff --git a/src/build_tools.rs b/src/build_tools.rs index 0c8abdf6d..d2bc7c2cd 100644 --- a/src/build_tools.rs +++ b/src/build_tools.rs @@ -67,7 +67,7 @@ impl fmt::Display for SchemaError { } impl Error for SchemaError { - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn source(&self) -> Option<&(dyn Error + 'static)> { None } diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 3e780237b..d799da473 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -80,7 +80,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_str(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_str(&'a self) -> ValResult> { self.strict_str() } @@ -93,7 +93,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_bytes(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_bytes(&'a self) -> ValResult> { self.strict_bytes() } @@ -106,7 +106,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_bool(&self) -> 
ValResult; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_bool(&self) -> ValResult { self.strict_bool() } @@ -119,7 +119,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_int(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_int(&'a self) -> ValResult> { self.strict_int() } @@ -147,7 +147,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } fn ultra_strict_float(&'a self) -> ValResult>; fn strict_float(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_float(&'a self) -> ValResult> { self.strict_float() } @@ -160,7 +160,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { self.strict_decimal(py) } @@ -173,7 +173,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_dict(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_dict(&'a self) -> ValResult> { self.strict_dict() } @@ -190,7 +190,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_list(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_list(&'a self) -> ValResult> { self.strict_list() } @@ -203,7 +203,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_tuple(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_tuple(&'a self) -> ValResult> { self.strict_tuple() } @@ -216,7 +216,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_set(&'a self) -> ValResult>; - 
#[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_set(&'a self) -> ValResult> { self.strict_set() } @@ -229,7 +229,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_frozenset(&'a self) -> ValResult>; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_frozenset(&'a self) -> ValResult> { self.strict_frozenset() } @@ -246,7 +246,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { } } fn strict_date(&self) -> ValResult; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_date(&self) -> ValResult { self.strict_date() } @@ -266,7 +266,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, ) -> ValResult; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_time( &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, @@ -289,7 +289,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, ) -> ValResult; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_datetime( &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, @@ -312,7 +312,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, ) -> ValResult; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn lax_timedelta( &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, diff --git a/src/input/input_json.rs b/src/input/input_json.rs index d9cb81fe2..d948e2493 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -21,7 +21,7 @@ use super::{ 
impl<'a> Input<'a> for JsonInput { /// This is required by since JSON object keys are always strings, I don't think it can be called - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn as_loc_item(&self) -> LocItem { match self { JsonInput::Int(i) => (*i).into(), @@ -102,7 +102,7 @@ impl<'a> Input<'a> for JsonInput { _ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), } } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_bytes(&'a self) -> ValResult> { self.validate_bytes(false) } @@ -196,7 +196,7 @@ impl<'a> Input<'a> for JsonInput { _ => Err(ValError::new(ErrorTypeDefaults::DictType, self)), } } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_dict(&'a self) -> ValResult> { self.validate_dict(false) } @@ -207,7 +207,7 @@ impl<'a> Input<'a> for JsonInput { _ => Err(ValError::new(ErrorTypeDefaults::ListType, self)), } } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_list(&'a self) -> ValResult> { self.validate_list(false) } @@ -219,7 +219,7 @@ impl<'a> Input<'a> for JsonInput { _ => Err(ValError::new(ErrorTypeDefaults::TupleType, self)), } } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_tuple(&'a self) -> ValResult> { self.validate_tuple(false) } @@ -231,7 +231,7 @@ impl<'a> Input<'a> for JsonInput { _ => Err(ValError::new(ErrorTypeDefaults::SetType, self)), } } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_set(&'a self) -> ValResult> { self.validate_set(false) } @@ -243,7 +243,7 @@ impl<'a> Input<'a> for JsonInput { _ => Err(ValError::new(ErrorTypeDefaults::FrozenSetType, self)), } } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_frozenset(&'a self) -> 
ValResult> { self.validate_frozenset(false) } @@ -278,7 +278,7 @@ impl<'a> Input<'a> for JsonInput { } // NO custom `lax_date` implementation, if strict_date fails, the validator will fallback to lax_datetime // then check there's no remainder - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_date(&self) -> ValResult { self.validate_date(false) } @@ -365,7 +365,7 @@ impl<'a> Input<'a> for String { InputValue::String(self) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn is_none(&self) -> bool { false } @@ -374,12 +374,12 @@ impl<'a> Input<'a> for String { None } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> { let class_name = class_name.to_string(); Err(ValError::new( @@ -405,7 +405,7 @@ impl<'a> Input<'a> for String { fn validate_bytes(&'a self, _strict: bool) -> ValResult> { Ok(self.as_bytes().into()) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_bytes(&'a self) -> ValResult> { self.validate_bytes(false) } @@ -427,11 +427,11 @@ impl<'a> Input<'a> for String { } } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn ultra_strict_float(&'a self) -> ValResult> { self.strict_float() } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_float(&'a self) -> ValResult> { Err(ValError::new(ErrorTypeDefaults::FloatType, self)) } @@ -443,47 +443,47 @@ impl<'a> Input<'a> for String { create_decimal(self.to_object(py).into_ref(py), self, py) 
} - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn validate_dict(&'a self, _strict: bool) -> ValResult> { Err(ValError::new(ErrorTypeDefaults::DictType, self)) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_dict(&'a self) -> ValResult> { self.validate_dict(false) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn validate_list(&'a self, _strict: bool) -> ValResult> { Err(ValError::new(ErrorTypeDefaults::ListType, self)) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_list(&'a self) -> ValResult> { self.validate_list(false) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn validate_tuple(&'a self, _strict: bool) -> ValResult> { Err(ValError::new(ErrorTypeDefaults::TupleType, self)) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_tuple(&'a self) -> ValResult> { self.validate_tuple(false) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn validate_set(&'a self, _strict: bool) -> ValResult> { Err(ValError::new(ErrorTypeDefaults::SetType, self)) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_set(&'a self) -> ValResult> { self.validate_set(false) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn validate_frozenset(&'a self, _strict: bool) -> ValResult> { Err(ValError::new(ErrorTypeDefaults::FrozenSetType, self)) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_frozenset(&'a self) -> ValResult> { self.validate_frozenset(false) } @@ -499,7 +499,7 @@ impl<'a> Input<'a> for String { fn validate_date(&self, _strict: bool) 
-> ValResult { bytes_as_date(self, self.as_bytes()) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_date(&self) -> ValResult { self.validate_date(false) } @@ -511,7 +511,7 @@ impl<'a> Input<'a> for String { ) -> ValResult { bytes_as_time(self, self.as_bytes(), microseconds_overflow_behavior) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_time( &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, @@ -526,7 +526,7 @@ impl<'a> Input<'a> for String { ) -> ValResult { bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_datetime( &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, @@ -541,7 +541,7 @@ impl<'a> Input<'a> for String { ) -> ValResult { bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_timedelta( &self, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, diff --git a/src/input/parse_json.rs b/src/input/parse_json.rs index 7603eaf67..20a107669 100644 --- a/src/input/parse_json.rs +++ b/src/input/parse_json.rs @@ -57,7 +57,7 @@ impl<'de> Deserialize<'de> for JsonInput { impl<'de> Visitor<'de> for JsonVisitor { type Value = JsonInput; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("any valid JSON value") } @@ -92,12 +92,12 @@ impl<'de> Deserialize<'de> for JsonInput { Ok(JsonInput::String(value)) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn visit_none(self) -> Result { unreachable!() } - 
#[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn visit_some(self, _: D) -> Result where D: serde::Deserializer<'de>, @@ -200,7 +200,7 @@ impl<'de> DeserializeSeed<'de> for KeyDeserializer { impl<'de> Visitor<'de> for KeyDeserializer { type Value = String; - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("a string key") } @@ -212,7 +212,7 @@ impl<'de> Visitor<'de> for KeyDeserializer { Ok(s.to_string()) } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn visit_string(self, _: String) -> Result where E: serde::de::Error, diff --git a/src/lib.rs b/src/lib.rs index 56efbc87f..cbf668b9d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#![cfg_attr(has_no_coverage, feature(no_coverage))] +#![cfg_attr(has_coverage_attribute, feature(coverage_attribute))] extern crate core; diff --git a/src/tools.rs b/src/tools.rs index 71f7ad60e..3c75decf1 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -48,7 +48,7 @@ impl<'py> SchemaDict<'py> for Option<&PyDict> { } } - #[cfg_attr(has_no_coverage, no_coverage)] + #[cfg_attr(has_coverage_attribute, coverage(off))] fn get_as_req(&'py self, key: &PyString) -> PyResult where T: FromPyObject<'py>, From 502fb9adfa96208236f0002b9e2d2cea8c7dd3d9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:21:56 +0000 Subject: [PATCH 032/550] Update pytest-codspeed requirement from ~=2.1.0 to ~=2.2.0 (#953) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 6ea2d16cf..ce5cc9f23 100644 --- a/tests/requirements.txt +++ 
b/tests/requirements.txt @@ -5,7 +5,7 @@ hypothesis==6.79.4 pandas==2.0.3; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' pytest==7.4.2 # we run codspeed benchmarks on x86_64 CPython (i.e. native github actions architecture) -pytest-codspeed~=2.1.0; implementation_name == "cpython" and platform_machine == 'x86_64' +pytest-codspeed~=2.2.0; implementation_name == "cpython" and platform_machine == 'x86_64' pytest-examples==0.0.10 pytest-speed==0.3.5 pytest-mock==3.11.1 From e8015318b9f5ff64a0672f84bc51231a9438f307 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:34:32 +0100 Subject: [PATCH 033/550] Bump black from 23.7.0 to 23.9.1 (#955) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 8e7b5c66c..561e352fa 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ -black==23.7.0 +black==23.9.1 griffe==0.36.2 pyright==1.1.325 ruff==0.0.287 From 245381f981ca3e2b146886e0ffc8011540ba7f6d Mon Sep 17 00:00:00 2001 From: zakstucke <44890343+zakstucke@users.noreply.github.com> Date: Mon, 18 Sep 2023 12:23:55 +0200 Subject: [PATCH 034/550] Implementation of __cause__ for ValidationError using ExceptionGroups (#780) --- python/pydantic_core/core_schema.py | 3 + src/errors/validation_exception.rs | 101 ++++++++++++- src/validators/function.rs | 18 ++- src/validators/generator.rs | 29 +++- src/validators/mod.rs | 5 + tests/requirements.txt | 1 + tests/test_errors.py | 210 ++++++++++++++++++++++++++++ 7 files changed, 355 insertions(+), 12 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 
74442b44c..0b4348e81 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -61,6 +61,8 @@ class CoreConfig(TypedDict, total=False): ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'. ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'. hide_input_in_errors: Whether to hide input data from `ValidationError` representation. + validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError. + Requires exceptiongroup backport pre Python 3.11. """ title: str @@ -92,6 +94,7 @@ class CoreConfig(TypedDict, total=False): ser_json_bytes: Literal['utf8', 'base64'] # default: 'utf8' # used to hide input data from ValidationError repr hide_input_in_errors: bool + validation_error_cause: bool # default: False IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 9dde6551c..e6563f597 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -3,10 +3,11 @@ use std::fmt::{Display, Write}; use std::str::from_utf8; use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError}; -use pyo3::intern; +use pyo3::ffi; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; +use pyo3::{intern, AsPyPointer}; use serde::ser::{Error, SerializeMap, SerializeSeq}; use serde::{Serialize, Serializer}; @@ -51,6 +52,7 @@ impl ValidationError { error: ValError, outer_location: Option, hide_input: bool, + validation_error_cause: bool, ) -> PyErr { match error { ValError::LineErrors(raw_errors) => { @@ -61,9 +63,19 @@ impl ValidationError { .collect(), None => raw_errors.into_iter().map(|e| e.into_py(py)).collect(), }; + let validation_error = Self::new(line_errors, title, error_mode, hide_input); + match Py::new(py, validation_error) { - Ok(err) => 
PyErr::from_value(err.into_ref(py)), + Ok(err) => { + if validation_error_cause { + // Will return an import error if the backport was needed and not installed: + if let Some(cause_problem) = ValidationError::maybe_add_cause(err.borrow(py), py) { + return cause_problem; + } + } + PyErr::from_value(err.as_ref(py)) + } Err(err) => err, } } @@ -93,6 +105,91 @@ impl ValidationError { pub fn use_default_error() -> PyErr { py_schema_error_type!("Uncaught UseDefault error, please check your usage of `default` validators.") } + + fn maybe_add_cause(self_: PyRef<'_, Self>, py: Python) -> Option { + let mut user_py_errs = vec![]; + for line_error in &self_.line_errors { + if let ErrorType::AssertionError { + error: Some(err), + context: _, + } + | ErrorType::ValueError { + error: Some(err), + context: _, + } = &line_error.error_type + { + let note: PyObject = if let Location::Empty = &line_error.location { + "Pydantic: cause of loc: root".into_py(py) + } else { + format!( + "Pydantic: cause of loc: {}", + // Location formats with a newline at the end, hence the trim() + line_error.location.to_string().trim() + ) + .into_py(py) + }; + + // Notes only support 3.11 upwards: + #[cfg(Py_3_11)] + { + // Add the location context as a note, no direct c api for this, + // fine performance wise, add_note() goes directly to C: "(PyCFunction)BaseException_add_note": + // https://github.com/python/cpython/blob/main/Objects/exceptions.c + if err.call_method1(py, "add_note", (format!("\n{note}"),)).is_ok() { + user_py_errs.push(err.clone_ref(py)); + } + } + + // Pre 3.11 notes support, use a UserWarning exception instead: + #[cfg(not(Py_3_11))] + { + use pyo3::exceptions::PyUserWarning; + + let wrapped = PyUserWarning::new_err((note,)); + wrapped.set_cause(py, Some(PyErr::from_value(err.as_ref(py)))); + user_py_errs.push(wrapped); + } + } + } + + // Only add the cause if there are actually python user exceptions to show: + if !user_py_errs.is_empty() { + let title = "Pydantic User Code 
Exceptions"; + + // Native ExceptionGroup(s) only supported 3.11 and later: + #[cfg(Py_3_11)] + let cause = { + use pyo3::exceptions::PyBaseExceptionGroup; + Some(PyBaseExceptionGroup::new_err((title, user_py_errs)).into_py(py)) + }; + + // Pre 3.11 ExceptionGroup support, use the python backport instead: + // If something's gone wrong with the backport, just don't add the cause: + #[cfg(not(Py_3_11))] + let cause = { + use pyo3::exceptions::PyImportError; + match py.import("exceptiongroup") { + Ok(py_mod) => match py_mod.getattr("ExceptionGroup") { + Ok(group_cls) => match group_cls.call1((title, user_py_errs)) { + Ok(group_instance) => Some(group_instance.into_py(py)), + Err(_) => None, + }, + Err(_) => None, + }, + Err(_) => return Some(PyImportError::new_err("validation_error_cause flag requires the exceptiongroup module backport to be installed when used on Python <3.11.")), + } + }; + + // Set the cause to the ValidationError: + if let Some(cause) = cause { + unsafe { + // PyException_SetCause _steals_ a reference to cause, so must use .into_ptr() + ffi::PyException_SetCause(self_.as_ptr(), cause.into_ptr()); + } + } + } + None + } } static URL_ENV_VAR: GILOnceCell = GILOnceCell::new(); diff --git a/src/validators/function.rs b/src/validators/function.rs index fa8a0673b..206e10a0c 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -288,6 +288,7 @@ pub struct FunctionWrapValidator { field_name: Option>, info_arg: bool, hide_input_in_errors: bool, + validation_error_cause: bool, } impl BuildValidator for FunctionWrapValidator { @@ -302,6 +303,7 @@ impl BuildValidator for FunctionWrapValidator { let validator = build_validator(schema.get_as_req(intern!(py, "schema"))?, config, definitions)?; let function_info = destructure_function_schema(schema)?; let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false); + let validation_error_cause: bool = config.get_as(intern!(py, 
"validation_error_cause"))?.unwrap_or(false); Ok(Self { validator: Box::new(validator), func: function_info.function.clone(), @@ -313,6 +315,7 @@ impl BuildValidator for FunctionWrapValidator { field_name: function_info.field_name.clone(), info_arg: function_info.info_arg, hide_input_in_errors, + validation_error_cause, } .into()) } @@ -356,6 +359,7 @@ impl Validator for FunctionWrapValidator { &self.validator, state, self.hide_input_in_errors, + self.validation_error_cause, ), }; self._validate( @@ -381,6 +385,7 @@ impl Validator for FunctionWrapValidator { &self.validator, state, self.hide_input_in_errors, + self.validation_error_cause, ), updated_field_name: field_name.to_string(), updated_field_value: field_value.to_object(py), @@ -478,12 +483,12 @@ impl AssignmentValidatorCallable { } macro_rules! py_err_string { - ($error_value:expr, $type_member:ident, $input:ident) => { + ($py:expr, $py_err:expr, $error_value:expr, $type_member:ident, $input:ident) => { match $error_value.str() { Ok(py_string) => match py_string.to_str() { Ok(_) => ValError::new( ErrorType::$type_member { - error: Some($error_value.into()), + error: Some($py_err.into_py($py)), context: None, }, $input, @@ -499,17 +504,18 @@ macro_rules! 
py_err_string { /// as validation errors, `TypeError` is now considered as a runtime error to catch errors in function signatures pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> ValError<'a> { if err.is_instance_of::(py) { - if let Ok(pydantic_value_error) = err.value(py).extract::() { + let error_value = err.value(py); + if let Ok(pydantic_value_error) = error_value.extract::() { pydantic_value_error.into_val_error(input) - } else if let Ok(pydantic_error_type) = err.value(py).extract::() { + } else if let Ok(pydantic_error_type) = error_value.extract::() { pydantic_error_type.into_val_error(input) } else if let Ok(validation_error) = err.value(py).extract::() { validation_error.into_val_error(py) } else { - py_err_string!(err.value(py), ValueError, input) + py_err_string!(py, err, error_value, ValueError, input) } } else if err.is_instance_of::(py) { - py_err_string!(err.value(py), AssertionError, input) + py_err_string!(py, err, err.value(py), AssertionError, input) } else if err.is_instance_of::(py) { ValError::Omit } else if err.is_instance_of::(py) { diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 1047e31bd..c52910500 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -19,6 +19,7 @@ pub struct GeneratorValidator { max_length: Option, name: String, hide_input_in_errors: bool, + validation_error_cause: bool, } impl BuildValidator for GeneratorValidator { @@ -37,12 +38,16 @@ impl BuildValidator for GeneratorValidator { let hide_input_in_errors: bool = config .get_as(pyo3::intern!(schema.py(), "hide_input_in_errors"))? .unwrap_or(false); + let validation_error_cause: bool = config + .get_as(pyo3::intern!(schema.py(), "validation_error_cause"))? 
+ .unwrap_or(false); Ok(Self { item_validator, name, min_length: schema.get_as(pyo3::intern!(schema.py(), "min_length"))?, max_length: schema.get_as(pyo3::intern!(schema.py(), "max_length"))?, hide_input_in_errors, + validation_error_cause, } .into()) } @@ -58,10 +63,16 @@ impl Validator for GeneratorValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let iterator = input.validate_iter()?; - let validator = self - .item_validator - .as_ref() - .map(|v| InternalValidator::new(py, "ValidatorIterator", v, state, self.hide_input_in_errors)); + let validator = self.item_validator.as_ref().map(|v| { + InternalValidator::new( + py, + "ValidatorIterator", + v, + state, + self.hide_input_in_errors, + self.validation_error_cause, + ) + }); let v_iterator = ValidatorIterator { iterator, @@ -69,6 +80,7 @@ impl Validator for GeneratorValidator { min_length: self.min_length, max_length: self.max_length, hide_input_in_errors: self.hide_input_in_errors, + validation_error_cause: self.validation_error_cause, }; Ok(v_iterator.into_py(py)) } @@ -105,6 +117,7 @@ struct ValidatorIterator { min_length: Option, max_length: Option, hide_input_in_errors: bool, + validation_error_cause: bool, } #[pymethods] @@ -117,6 +130,7 @@ impl ValidatorIterator { let min_length = slf.min_length; let max_length = slf.max_length; let hide_input_in_errors = slf.hide_input_in_errors; + let validation_error_cause = slf.validation_error_cause; let Self { validator, iterator, .. 
} = &mut *slf; @@ -143,6 +157,7 @@ impl ValidatorIterator { val_error, None, hide_input_in_errors, + validation_error_cause, )); } } @@ -169,6 +184,7 @@ impl ValidatorIterator { val_error, None, hide_input_in_errors, + validation_error_cause, )); } } @@ -217,6 +233,7 @@ pub struct InternalValidator { recursion_guard: RecursionGuard, validation_mode: InputType, hide_input_in_errors: bool, + validation_error_cause: bool, } impl fmt::Debug for InternalValidator { @@ -232,6 +249,7 @@ impl InternalValidator { validator: &CombinedValidator, state: &ValidationState, hide_input_in_errors: bool, + validation_error_cause: bool, ) -> Self { let extra = state.extra(); Self { @@ -246,6 +264,7 @@ impl InternalValidator { recursion_guard: state.recursion_guard.clone(), validation_mode: extra.mode, hide_input_in_errors, + validation_error_cause, } } @@ -277,6 +296,7 @@ impl InternalValidator { e, outer_location, self.hide_input_in_errors, + self.validation_error_cause, ) }) } @@ -305,6 +325,7 @@ impl InternalValidator { e, outer_location, self.hide_input_in_errors, + self.validation_error_cause, ) }) } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 5824888ec..9a9c6a185 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -107,6 +107,7 @@ pub struct SchemaValidator { #[pyo3(get)] title: PyObject, hide_input_in_errors: bool, + validation_error_cause: bool, } #[pymethods] @@ -133,12 +134,14 @@ impl SchemaValidator { None => validator.get_name().into_py(py), }; let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false); + let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false); Ok(Self { validator, definitions, schema: schema.into_py(py), title, hide_input_in_errors, + validation_error_cause, }) } @@ -329,6 +332,7 @@ impl SchemaValidator { error, None, self.hide_input_in_errors, + self.validation_error_cause, ) } } @@ -385,6 +389,7 @@ impl<'py> SelfValidator<'py> 
{ schema: py.None(), title: "Self Schema".into_py(py), hide_input_in_errors: false, + validation_error_cause: false, }) } } diff --git a/tests/requirements.txt b/tests/requirements.txt index ce5cc9f23..ae8b9e50d 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -14,3 +14,4 @@ pytest-timeout==2.1.0 pytz==2023.3.post1 # numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux numpy==1.25.2; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' +exceptiongroup==1.1; python_version < "3.11" diff --git a/tests/test_errors.py b/tests/test_errors.py index 881abd508..fe71b9860 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,11 +1,15 @@ +import enum import re +import sys from decimal import Decimal from typing import Any, Optional +from unittest.mock import patch import pytest from dirty_equals import HasRepr, IsInstance, IsJson, IsStr from pydantic_core import ( + CoreConfig, PydanticCustomError, PydanticKnownError, PydanticOmit, @@ -517,6 +521,212 @@ def test_all_errors(): pytest.fail('core_schema.ErrorType needs to be updated') +@pytest.mark.skipif(sys.version_info < (3, 11), reason='This is the modern version used post 3.10.') +def test_validation_error_cause_contents(): + enabled_config: CoreConfig = {'validation_error_cause': True} + + def multi_raise_py_error(v: Any) -> Any: + try: + raise AssertionError('Wrong') + except AssertionError as e: + raise ValueError('Oh no!') from e + + s2 = SchemaValidator(core_schema.no_info_plain_validator_function(multi_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s2.validate_python('anything') + + cause_group = exc_info.value.__cause__ + assert isinstance(cause_group, BaseExceptionGroup) + assert len(cause_group.exceptions) == 1 + + cause = cause_group.exceptions[0] + assert cause.__notes__ + assert 
cause.__notes__[-1].startswith('\nPydantic: ') + assert repr(cause) == repr(ValueError('Oh no!')) + assert cause.__traceback__ is not None + + sub_cause = cause.__cause__ + assert repr(sub_cause) == repr(AssertionError('Wrong')) + assert sub_cause.__cause__ is None + assert sub_cause.__traceback__ is not None + + # Edge case: make sure a deep inner ValidationError(s) causing a validator failure doesn't cause any problems: + def outer_raise_py_error(v: Any) -> Any: + try: + s2.validate_python('anything') + except ValidationError as e: + raise ValueError('Sub val failure') from e + + s3 = SchemaValidator(core_schema.no_info_plain_validator_function(outer_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s3.validate_python('anything') + + assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) + assert len(exc_info.value.__cause__.exceptions) == 1 + cause = exc_info.value.__cause__.exceptions[0] + assert cause.__notes__ and cause.__notes__[-1].startswith('\nPydantic: ') + assert repr(cause) == repr(ValueError('Sub val failure')) + subcause = cause.__cause__ + assert isinstance(subcause, ValidationError) + + cause_group = subcause.__cause__ + assert isinstance(cause_group, BaseExceptionGroup) + assert len(cause_group.exceptions) == 1 + + cause = cause_group.exceptions[0] + assert cause.__notes__ + assert cause.__notes__[-1].startswith('\nPydantic: ') + assert repr(cause) == repr(ValueError('Oh no!')) + assert cause.__traceback__ is not None + + sub_cause = cause.__cause__ + assert repr(sub_cause) == repr(AssertionError('Wrong')) + assert sub_cause.__cause__ is None + assert sub_cause.__traceback__ is not None + + +@pytest.mark.skipif(sys.version_info >= (3, 11), reason='This is the backport/legacy version used pre 3.11 only.') +def test_validation_error_cause_contents_legacy(): + from exceptiongroup import BaseExceptionGroup + + enabled_config: CoreConfig = {'validation_error_cause': True} + + def 
multi_raise_py_error(v: Any) -> Any: + try: + raise AssertionError('Wrong') + except AssertionError as e: + raise ValueError('Oh no!') from e + + s2 = SchemaValidator(core_schema.no_info_plain_validator_function(multi_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s2.validate_python('anything') + + cause_group = exc_info.value.__cause__ + assert isinstance(cause_group, BaseExceptionGroup) + assert len(cause_group.exceptions) == 1 + + cause = cause_group.exceptions[0] + assert repr(cause).startswith("UserWarning('Pydantic: ") + + assert cause.__cause__ is not None + cause = cause.__cause__ + assert repr(cause) == repr(ValueError('Oh no!')) + assert cause.__traceback__ is not None + + sub_cause = cause.__cause__ + assert repr(sub_cause) == repr(AssertionError('Wrong')) + assert sub_cause.__cause__ is None + assert sub_cause.__traceback__ is not None + + # Make sure a deep inner ValidationError(s) causing a validator failure doesn't cause any problems: + def outer_raise_py_error(v: Any) -> Any: + try: + s2.validate_python('anything') + except ValidationError as e: + raise ValueError('Sub val failure') from e + + s3 = SchemaValidator(core_schema.no_info_plain_validator_function(outer_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s3.validate_python('anything') + + assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) + assert len(exc_info.value.__cause__.exceptions) == 1 + cause = exc_info.value.__cause__.exceptions[0] + assert repr(cause).startswith("UserWarning('Pydantic: ") + assert cause.__cause__ is not None + cause = cause.__cause__ + assert repr(cause) == repr(ValueError('Sub val failure')) + subcause = cause.__cause__ + assert isinstance(subcause, ValidationError) + + cause_group = subcause.__cause__ + assert isinstance(cause_group, BaseExceptionGroup) + assert len(cause_group.exceptions) == 1 + + cause = cause_group.exceptions[0] + assert 
repr(cause).startswith("UserWarning('Pydantic: ") + assert cause.__cause__ is not None + cause = cause.__cause__ + assert repr(cause) == repr(ValueError('Oh no!')) + assert cause.__traceback__ is not None + + sub_cause = cause.__cause__ + assert repr(sub_cause) == repr(AssertionError('Wrong')) + assert sub_cause.__cause__ is None + assert sub_cause.__traceback__ is not None + + +class CauseResult(enum.Enum): + CAUSE = enum.auto() + NO_CAUSE = enum.auto() + IMPORT_ERROR = enum.auto() + + +@pytest.mark.parametrize( + 'desc,config,expected_result', + [ # Without the backport should still work after 3.10 as not needed: + ( + 'Enabled', + {'validation_error_cause': True}, + CauseResult.CAUSE if sys.version_info >= (3, 11) else CauseResult.IMPORT_ERROR, + ), + ('Disabled specifically', {'validation_error_cause': False}, CauseResult.NO_CAUSE), + ('Disabled implicitly', {}, CauseResult.NO_CAUSE), + ], +) +def test_validation_error_cause_config_variants(desc: str, config: CoreConfig, expected_result: CauseResult): + # Simulate the package being missing: + with patch.dict('sys.modules', {'exceptiongroup': None}): + + def singular_raise_py_error(v: Any) -> Any: + raise ValueError('Oh no!') + + s = SchemaValidator(core_schema.no_info_plain_validator_function(singular_raise_py_error), config=config) + + if expected_result is CauseResult.IMPORT_ERROR: + # Confirm error message contains "requires the exceptiongroup module" in the middle of the string: + with pytest.raises(ImportError, match='requires the exceptiongroup module'): + s.validate_python('anything') + elif expected_result is CauseResult.CAUSE: + with pytest.raises(ValidationError) as exc_info: + s.validate_python('anything') + assert exc_info.value.__cause__ is not None + assert hasattr(exc_info.value.__cause__, 'exceptions') + assert len(exc_info.value.__cause__.exceptions) == 1 + assert repr(exc_info.value.__cause__.exceptions[0]) == repr(ValueError('Oh no!')) + elif expected_result is CauseResult.NO_CAUSE: + with 
pytest.raises(ValidationError) as exc_info: + s.validate_python('anything') + assert exc_info.value.__cause__ is None + else: + raise AssertionError('Unhandled result: {}'.format(expected_result)) + + +def test_validation_error_cause_traceback_preserved(): + """Makes sure historic bug of traceback being lost is fixed.""" + + enabled_config: CoreConfig = {'validation_error_cause': True} + + def singular_raise_py_error(v: Any) -> Any: + raise ValueError('Oh no!') + + s1 = SchemaValidator(core_schema.no_info_plain_validator_function(singular_raise_py_error), config=enabled_config) + with pytest.raises(ValidationError) as exc_info: + s1.validate_python('anything') + + base_errs = getattr(exc_info.value.__cause__, 'exceptions', []) + assert len(base_errs) == 1 + base_err = base_errs[0] + + # Get to the root error: + cause = base_err + while cause.__cause__ is not None: + cause = cause.__cause__ + + # Should still have a traceback: + assert cause.__traceback__ is not None + + class BadRepr: def __repr__(self): raise RuntimeError('bad repr') From e4eed14b7b8af6f71f3bce7b4ff2f3f034191d55 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Sep 2023 13:03:32 +0100 Subject: [PATCH 035/550] Bump pyright from 1.1.325 to 1.1.327 (#970) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 561e352fa..eb8321437 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 griffe==0.36.2 -pyright==1.1.325 +pyright==1.1.327 ruff==0.0.287 mypy==1.5.1 From a107c96a00f87154c4da2cf60d4ba9462a87339f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Sep 2023 12:07:28 +0000 Subject: 
[PATCH 036/550] Bump serde_json from 1.0.106 to 1.0.107 (#971) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4c8b1d58f..2986812aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -422,9 +422,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.106" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ "indexmap", "itoa", diff --git a/Cargo.toml b/Cargo.toml index 3c7e18d4f..75fc7ae17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ pyo3 = { version = "0.19.2", features = ["generate-import-lib", "num-bigint"] } regex = "1.9.5" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.2" -serde_json = {version = "1.0.106", features = ["arbitrary_precision", "preserve_order"]} +serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.188", features = ["derive"] } speedate = "0.12.0" From 1d102381b04dd53e61665d84281161fc351ffbee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Sep 2023 12:10:10 +0000 Subject: [PATCH 037/550] Bump ruff from 0.0.287 to 0.0.290 (#968) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index eb8321437..ff8f85657 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 
griffe==0.36.2 pyright==1.1.327 -ruff==0.0.287 +ruff==0.0.290 mypy==1.5.1 From a17e9b632eb25e67a0b668a33cfd6dc79b6392db Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 18 Sep 2023 15:27:43 -0500 Subject: [PATCH 038/550] Populate defs from defs schema (#972) Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- Cargo.lock | 2 +- Cargo.toml | 2 +- generate_self_schema.py | 41 +- python/pydantic_core/core_schema.py | 9 +- src/definitions.rs | 3 +- src/serializers/shared.rs | 8 - .../type_serializers/definitions.rs | 15 +- src/validators/definitions.rs | 7 +- src/validators/mod.rs | 8 - tests/benchmarks/complete_schema.py | 34 +- tests/benchmarks/test_micro_benchmarks.py | 44 +- tests/serializers/test_definitions.py | 27 +- .../serializers/test_definitions_recursive.py | 144 ++--- tests/test.rs | 47 +- tests/test_build.py | 1 + tests/test_garbage_collection.py | 30 +- tests/test_hypothesis.py | 30 +- tests/test_json.py | 27 +- tests/test_schema_functions.py | 8 +- tests/test_typing.py | 34 +- tests/validators/test_definitions.py | 28 +- .../validators/test_definitions_recursive.py | 586 +++++++++--------- tests/validators/test_tagged_union.py | 31 +- tests/validators/test_union.py | 17 +- 25 files changed, 607 insertions(+), 578 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b54ac0d8..85174fa41 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -207,7 +207,7 @@ jobs: run: | pip install pdm maturin pdm venv create --with-pip - pdm install -G testing + pdm install -G testing -G email pdm run pip install maturin pdm run bash -c 'cd ../pydantic-core && make build-dev' working-directory: pydantic diff --git a/Cargo.lock b/Cargo.lock index 2986812aa..55224a954 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,7 +242,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.7.0" +version 
= "2.8.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 75fc7ae17..a4dfda154 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.7.0" +version = "2.8.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" diff --git a/generate_self_schema.py b/generate_self_schema.py index a745471d8..2c190bbad 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -14,7 +14,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Set, Type, Union -from typing_extensions import get_args, get_origin, is_typeddict +from typing_extensions import TypedDict, get_args, get_origin, is_typeddict TypingUnionType = Type[Union[str, int]] @@ -45,13 +45,13 @@ schema_ref_validator = {'type': 'definition-ref', 'schema_ref': 'root-schema'} -def get_schema(obj) -> core_schema.CoreSchema: +def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: if isinstance(obj, str): return {'type': obj} elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal): return {'type': obj.__name__.lower()} elif is_typeddict(obj): - return type_dict_schema(obj) + return type_dict_schema(obj, definitions) elif obj == Any or obj == type: return {'type': 'any'} if isinstance(obj, type) and issubclass(obj, core_schema.Protocol): @@ -60,7 +60,7 @@ def get_schema(obj) -> core_schema.CoreSchema: origin = get_origin(obj) assert origin is not None, f'origin cannot be None, obj={obj}, you probably need to fix generate_self_schema.py' if origin is Union or origin is TypesUnionType: - return union_schema(obj) + return union_schema(obj, definitions) elif obj is Callable or origin is Callable: return {'type': 'callable'} elif origin is core_schema.Literal: @@ -68,14 +68,14 @@ def get_schema(obj) -> core_schema.CoreSchema: assert expected, f'literal "expected" cannot be empty, obj={obj}' return 
{'type': 'literal', 'expected': expected} elif issubclass(origin, List): - return {'type': 'list', 'items_schema': get_schema(obj.__args__[0])} + return {'type': 'list', 'items_schema': get_schema(obj.__args__[0], definitions)} elif issubclass(origin, Set): - return {'type': 'set', 'items_schema': get_schema(obj.__args__[0])} + return {'type': 'set', 'items_schema': get_schema(obj.__args__[0], definitions)} elif issubclass(origin, Dict): return { 'type': 'dict', - 'keys_schema': get_schema(obj.__args__[0]), - 'values_schema': get_schema(obj.__args__[1]), + 'keys_schema': get_schema(obj.__args__[0], definitions), + 'values_schema': get_schema(obj.__args__[1], definitions), } elif issubclass(origin, Type): # can't really use 'is-instance' since this is used for the class_ parameter of 'is-instance' validators @@ -107,7 +107,9 @@ def tagged_union(std_union_schema: Dict[str, Any], discriminator_key: str, ref: defined_ser_schema = False -def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901 +def type_dict_schema( # noqa: C901 + typed_dict: type[TypedDict], definitions: dict[str, core_schema.CoreSchema] +) -> dict[str, Any]: global defined_ser_schema required_keys = getattr(typed_dict, '__required_keys__', set()) @@ -154,13 +156,14 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901 required = True field_type = field_type.__args__[0] - schema = get_schema(field_type) + schema = get_schema(field_type, definitions) if fr_arg == 'SerSchema': if defined_ser_schema: schema = {'type': 'definition-ref', 'schema_ref': 'ser-schema'} else: defined_ser_schema = True - schema = tagged_union(schema, 'type', 'ser-schema') + definitions['ser-schema'] = tagged_union(schema, 'type', 'ser-schema') + schema = {'type': 'definition-ref', 'schema_ref': 'ser-schema'} elif fr_arg.endswith('SerSchema'): schema = tagged_union(schema, 'type') @@ -172,8 +175,8 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901 return {'type': 'typed-dict', 'fields': 
fields, 'extra_behavior': 'forbid'} -def union_schema(union_type: UnionType) -> core_schema.UnionSchema | core_schema.DefinitionReferenceSchema: - return {'type': 'union', 'choices': [get_schema(arg) for arg in union_type.__args__]} +def union_schema(union_type: UnionType, definitions) -> core_schema.UnionSchema | core_schema.DefinitionReferenceSchema: + return {'type': 'union', 'choices': [get_schema(arg, definitions) for arg in union_type.__args__]} def all_literal_values(type_: type[core_schema.Literal]) -> list[any]: @@ -196,16 +199,24 @@ def main() -> None: schema_union = core_schema.CoreSchema assert get_origin(schema_union) is Union, 'expected core_schema.CoreSchema to be a Union' + definitions: dict[str, core_schema.CoreSchema] = {} + choices = {} for s in schema_union.__args__: type_ = s.__annotations__['type'] m = re.search(r"Literal\['(.+?)']", type_.__forward_arg__) assert m, f'Unknown schema type: {type_}' key = m.group(1) - value = get_schema(s) + value = get_schema(s, definitions) choices[key] = value - schema = {'type': 'tagged-union', 'ref': 'root-schema', 'discriminator': 'type', 'choices': choices} + schema = core_schema.definitions_schema( + schema=core_schema.definition_reference_schema(schema_ref='root-schema'), + definitions=[ + core_schema.tagged_union_schema(choices, discriminator='type', ref='root-schema'), + *definitions.values(), + ], + ) python_code = ( f'# this file is auto-generated by generate_self_schema.py, DO NOT edit manually\nself_schema = {schema}\n' ) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 0b4348e81..db0b1bd54 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3772,9 +3772,14 @@ def definition_reference_schema( from pydantic_core import SchemaValidator, core_schema schema_definition = core_schema.definition_reference_schema('list-schema') - schema = core_schema.list_schema(items_schema=schema_definition, ref='list-schema') + 
schema = core_schema.definitions_schema( + schema=schema_definition, + definitions=[ + core_schema.list_schema(items_schema=schema_definition, ref='list-schema'), + ], + ) v = SchemaValidator(schema) - assert v.validate_python([[]]) == [[]] + assert v.validate_python([()]) == [[]] ``` Args: diff --git a/src/definitions.rs b/src/definitions.rs index 44360c83f..0d01fd2ae 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -82,7 +82,8 @@ impl DefinitionsBuilder { } /// Retrieve an item definition using a ReferenceId - /// Will raise an error if the definition for that reference does not yet exist + /// If the definition doesn't yet exist (as happens in recursive types) then we create it + /// At the end (in finish()) we check that there are no undefined definitions pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> { let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) { Some(v) => v, diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 0c97ee635..7c24ff6db 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -21,7 +21,6 @@ use super::errors::se_err_py_err; use super::extra::Extra; use super::infer::infer_json_key; use super::ob_type::{IsType, ObType}; -use super::type_serializers::definitions::DefinitionRefSerializer; pub(crate) trait BuildSerializer: Sized { const EXPECTED_TYPE: &'static str; @@ -207,13 +206,6 @@ impl BuildSerializer for CombinedSerializer { config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let py: Python = schema.py(); - if let Some(schema_ref) = schema.get_as::(intern!(py, "ref"))? 
{ - let inner_ser = Self::_build(schema, config, definitions)?; - let ser_id = definitions.add_definition(schema_ref, inner_ser)?; - return Ok(DefinitionRefSerializer::from_id(ser_id)); - } - Self::_build(schema, config, definitions) } } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 4d2a0371c..19d6e75ed 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -25,9 +25,12 @@ impl BuildSerializer for DefinitionsSerializerBuilder { let schema_definitions: &PyList = schema.get_as_req(intern!(py, "definitions"))?; - for schema_def in schema_definitions { - CombinedSerializer::build(schema_def.downcast()?, config, definitions)?; - // no need to store the serializer here, it has already been stored in definitions if necessary + for schema_definition in schema_definitions { + let reference = schema_definition + .extract::<&PyDict>()? + .get_as_req::(intern!(py, "ref"))?; + let serializer = CombinedSerializer::build(schema_definition.downcast()?, config, definitions)?; + definitions.add_definition(reference, serializer)?; } let inner_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?; @@ -40,12 +43,6 @@ pub struct DefinitionRefSerializer { serializer_id: usize, } -impl DefinitionRefSerializer { - pub fn from_id(serializer_id: usize) -> CombinedSerializer { - Self { serializer_id }.into() - } -} - impl BuildSerializer for DefinitionRefSerializer { const EXPECTED_TYPE: &'static str = "definition-ref"; diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 0dfc80475..3a35fce4c 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -25,8 +25,11 @@ impl BuildValidator for DefinitionsValidatorBuilder { let schema_definitions: &PyList = schema.get_as_req(intern!(py, "definitions"))?; for schema_definition in schema_definitions { - build_validator(schema_definition, config, definitions)?; - 
// no need to store the validator here, it has already been stored in definitions if necessary + let reference = schema_definition + .extract::<&PyDict>()? + .get_as_req::(intern!(py, "ref"))?; + let validator = build_validator(schema_definition, config, definitions)?; + definitions.add_definition(reference, validator)?; } let inner_schema: &PyAny = schema.get_as_req(intern!(py, "schema"))?; diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 9a9c6a185..9440f1527 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -60,7 +60,6 @@ mod with_default; pub use with_default::DefaultType; -use self::definitions::DefinitionRefValidator; pub use self::validation_state::ValidationState; #[pyclass(module = "pydantic_core._pydantic_core", name = "Some")] @@ -413,13 +412,6 @@ fn build_specific_validator<'a, T: BuildValidator>( config: Option<&'a PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let py = schema_dict.py(); - if let Some(schema_ref) = schema_dict.get_as::(intern!(py, "ref"))? 
{ - let inner_val = T::build(schema_dict, config, definitions)?; - let validator_id = definitions.add_definition(schema_ref, inner_val)?; - return Ok(DefinitionRefValidator::new(validator_id).into()); - } - T::build(schema_dict, config, definitions) .map_err(|err| py_schema_error_type!("Error building \"{}\" validator:\n {}", val_type, err)) } diff --git a/tests/benchmarks/complete_schema.py b/tests/benchmarks/complete_schema.py index 69a27277b..d4eff16b2 100644 --- a/tests/benchmarks/complete_schema.py +++ b/tests/benchmarks/complete_schema.py @@ -191,22 +191,28 @@ def wrap_function(input_value, validator, info): 'field_recursive': { 'type': 'model-field', 'schema': { - 'ref': 'Branch', - 'type': 'typed-dict', - 'fields': { - 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'sub_branch': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': { - 'type': 'nullable', - 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}, + 'type': 'definitions', + 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}, + 'definitions': [ + { + 'type': 'typed-dict', + 'fields': { + 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, + 'sub_branch': { + 'type': 'typed-dict-field', + 'schema': { + 'type': 'default', + 'schema': { + 'type': 'nullable', + 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}, + }, + 'default': None, + }, }, - 'default': None, }, - }, - }, + 'ref': 'Branch', + } + ], }, }, }, diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 6c9de2fb2..7e6799a74 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -298,28 +298,28 @@ class CoreBranch: __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' v = SchemaValidator( - { - 'ref': 'Branch', - 'type': 'model', - 'cls': CoreBranch, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'width': 
{'type': 'model-field', 'schema': {'type': 'int'}}, - 'branch': { - 'type': 'model-field', - 'schema': { - 'type': 'default', - 'schema': { - 'type': 'nullable', - 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}, - }, - 'default': None, - }, - }, - }, - }, - } + core_schema.definitions_schema( + core_schema.definition_reference_schema(schema_ref='Branch'), + [ + core_schema.model_schema( + CoreBranch, + core_schema.model_fields_schema( + { + 'width': core_schema.model_field(core_schema.int_schema()), + 'branch': core_schema.model_field( + core_schema.with_default_schema( + core_schema.nullable_schema( + core_schema.definition_reference_schema(schema_ref='Branch') + ), + default=None, + ) + ), + } + ), + ref='Branch', + ) + ], + ) ) benchmark(v.validate_python, definition_model_data) diff --git a/tests/serializers/test_definitions.py b/tests/serializers/test_definitions.py index b52e1e27f..61b1e9cf7 100644 --- a/tests/serializers/test_definitions.py +++ b/tests/serializers/test_definitions.py @@ -42,9 +42,12 @@ def test_repeated_ref(): SchemaSerializer( core_schema.tuple_positional_schema( [ - core_schema.int_schema(ref='foobar'), - core_schema.definition_reference_schema('foobar'), - core_schema.int_schema(ref='foobar'), + core_schema.definitions_schema( + core_schema.definition_reference_schema('foobar'), [core_schema.int_schema(ref='foobar')] + ), + core_schema.definitions_schema( + core_schema.definition_reference_schema('foobar'), [core_schema.int_schema(ref='foobar')] + ), ] ) ) @@ -53,14 +56,16 @@ def test_repeated_ref(): def test_repeat_after(): with pytest.raises(SchemaError, match='SchemaError: Duplicate ref: `foobar`'): SchemaSerializer( - core_schema.tuple_positional_schema( - [ - core_schema.definitions_schema( - core_schema.list_schema(core_schema.definition_reference_schema('foobar')), - [core_schema.int_schema(ref='foobar')], - ), - core_schema.int_schema(ref='foobar'), - ] + core_schema.definitions_schema( + 
core_schema.tuple_positional_schema( + [ + core_schema.definitions_schema( + core_schema.definition_reference_schema('foobar'), [core_schema.int_schema(ref='foobar')] + ), + core_schema.definition_reference_schema('foobar'), + ] + ), + [core_schema.int_schema(ref='foobar')], ) ) diff --git a/tests/serializers/test_definitions_recursive.py b/tests/serializers/test_definitions_recursive.py index 8d5a80065..014af35c0 100644 --- a/tests/serializers/test_definitions_recursive.py +++ b/tests/serializers/test_definitions_recursive.py @@ -1,21 +1,24 @@ import pytest -from pydantic_core import SchemaSerializer +from pydantic_core import SchemaSerializer, core_schema def test_branch_nullable(): s = SchemaSerializer( - { - 'type': 'typed-dict', - 'ref': 'Branch', - 'fields': { - 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'sub_branch': { - 'type': 'typed-dict-field', - 'schema': {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}}, - }, - }, - } + core_schema.definitions_schema( + core_schema.definition_reference_schema('Branch'), + [ + core_schema.typed_dict_schema( + { + 'name': core_schema.typed_dict_field(core_schema.str_schema()), + 'sub_branch': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('Branch')) + ), + }, + ref='Branch', + ) + ], + ) ) assert s.to_python({'name': 'root', 'sub_branch': {'name': 'branch', 'sub_branch': None}}) == { 'name': 'root', @@ -29,17 +32,20 @@ def test_branch_nullable(): def test_cyclic_recursion(): s = SchemaSerializer( - { - 'type': 'typed-dict', - 'ref': 'Branch', - 'fields': { - 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'sub_branch': { - 'type': 'typed-dict-field', - 'schema': {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}}, - }, - }, - } + core_schema.definitions_schema( + core_schema.definition_reference_schema('Branch'), + [ + core_schema.typed_dict_schema( + { + 'name': 
core_schema.typed_dict_field(core_schema.str_schema()), + 'sub_branch': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('Branch')) + ), + }, + ref='Branch', + ) + ], + ) ) v = {'name': 'root'} v['sub_branch'] = v @@ -53,24 +59,24 @@ def test_cyclic_recursion(): def test_custom_ser(): s = SchemaSerializer( - { - 'type': 'typed-dict', - 'ref': 'Branch', - 'fields': { - 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'sub_branch': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'nullable', - 'schema': { - 'type': 'definition-ref', - 'schema_ref': 'Branch', - 'serialization': {'type': 'to-string', 'when_used': 'always'}, - }, + core_schema.definitions_schema( + core_schema.definition_reference_schema('Branch'), + [ + core_schema.typed_dict_schema( + { + 'name': core_schema.typed_dict_field(core_schema.str_schema()), + 'sub_branch': core_schema.typed_dict_field( + core_schema.nullable_schema( + core_schema.definition_reference_schema( + 'Branch', serialization=core_schema.to_string_ser_schema(when_used='always') + ) + ) + ), }, - }, - }, - } + ref='Branch', + ) + ], + ) ) assert s.to_python({'name': 'root', 'sub_branch': {'name': 'branch', 'sub_branch': None}}) == { 'name': 'root', @@ -80,43 +86,39 @@ def test_custom_ser(): def test_recursive_function(): s = SchemaSerializer( - { - 'type': 'typed-dict', - 'fields': { - 'root': {'type': 'typed-dict-field', 'schema': {'type': 'definition-ref', 'schema_ref': 'my_ref'}} - }, - 'ref': 'my_ref', - 'serialization': {'type': 'function-wrap', 'info_arg': True, 'function': lambda x, _1, _2: x}, - } + core_schema.definitions_schema( + core_schema.definition_reference_schema('my_ref'), + [ + core_schema.typed_dict_schema( + {'root': core_schema.typed_dict_field(core_schema.definition_reference_schema('my_ref'))}, + ref='my_ref', + serialization=core_schema.wrap_serializer_function_ser_schema(function=lambda x, _handler: x), + ) + ], + ) ) assert 
s.to_python({'root': {'root': {}}}) == {'root': {'root': {}}} def test_recursive_function_deeper_ref(): s = SchemaSerializer( - { - 'type': 'typed-dict', - 'fields': { - 'a': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'typed-dict', - 'ref': 'my_ref', - 'fields': { - 'b': { - 'type': 'typed-dict-field', - 'schema': {'type': 'definition-ref', 'schema_ref': 'my_ref'}, - } - }, - }, - } - }, - 'serialization': { - 'type': 'function-wrap', - 'is_field_serializer': False, - 'info_arg': True, - 'function': lambda x, _1, _2: x, + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field( + core_schema.definitions_schema( + core_schema.definition_reference_schema('my_ref'), + [ + core_schema.typed_dict_schema( + {'b': core_schema.typed_dict_field(core_schema.definition_reference_schema('my_ref'))}, + ref='my_ref', + ) + ], + ) + ) }, - } + serialization=core_schema.wrap_serializer_function_ser_schema( + function=lambda x, _handler: x, is_field_serializer=False + ), + ) ) assert s.to_python({'a': {'b': {'b': {}}}}) == {'a': {'b': {'b': {}}}} diff --git a/tests/test.rs b/tests/test.rs index fc3456eec..9b2fb99b5 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -7,22 +7,43 @@ mod tests { #[test] fn test_build_schema_serializer() { Python::with_gil(|py| { + // 'type': 'typed-dict', + // 'fields': { + // 'root': { + // 'type': 'typed-dict-field', + // 'schema': { + // 'type': 'definition-ref', + // 'schema_ref': 'C-ref', + // }, + // }, + // }, + // 'ref': 'C-ref', + // 'serialization': { + // 'type': 'function-wrap', + // 'function': lambda: None, + // }, let code = r#"{ - 'type': 'typed-dict', - 'fields': { - 'root': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'definition-ref', - 'schema_ref': 'C-ref', + 'type': 'definitions', + 'schema': {'type': 'definition-ref', 'schema_ref': 'C-ref'}, + 'definitions': [ + { + 'type': 'typed-dict', + 'fields': { + 'root': { + 'type': 'typed-dict-field', + 'schema': { + 'type': 'definition-ref', + 
'schema_ref': 'C-ref', + } + }, + }, + 'ref': 'C-ref', + 'serialization': { + 'type': 'function-wrap', + 'function': lambda: None, }, }, - }, - 'ref': 'C-ref', - 'serialization': { - 'type': 'function-wrap', - 'function': lambda: None, - }, + ] }"#; let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap(); SchemaSerializer::py_new(py, schema, None).unwrap(); diff --git a/tests/test_build.py b/tests/test_build.py index 7cc0b69af..095eb6887 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -61,6 +61,7 @@ def test_pickle(pickle_protocol: int) -> None: assert repr(v1) == repr(v2) +@pytest.mark.skip def test_schema_definition_error(): schema = {'type': 'union', 'choices': []} schema['choices'].append({'type': 'nullable', 'schema': schema}) diff --git a/tests/test_garbage_collection.py b/tests/test_garbage_collection.py index 9dec6e9ea..d848c91ea 100644 --- a/tests/test_garbage_collection.py +++ b/tests/test_garbage_collection.py @@ -7,6 +7,16 @@ from pydantic_core import SchemaSerializer, SchemaValidator, core_schema +GC_TEST_SCHEMA_INNER = core_schema.definitions_schema( + core_schema.definition_reference_schema(schema_ref='model'), + [ + core_schema.typed_dict_schema( + {'x': core_schema.typed_dict_field(core_schema.definition_reference_schema(schema_ref='model'))}, + ref='model', + ) + ], +) + @pytest.mark.xfail( condition=platform.python_implementation() == 'PyPy', reason='https://foss.heptapod.net/pypy/pypy/-/issues/3899' @@ -17,15 +27,7 @@ class BaseModel: __schema__: SchemaSerializer def __init_subclass__(cls) -> None: - cls.__schema__ = SchemaSerializer( - core_schema.model_schema( - cls, - core_schema.typed_dict_schema( - {'x': core_schema.typed_dict_field(core_schema.definition_reference_schema('model'))} - ), - ref='model', - ) - ) + cls.__schema__ = SchemaSerializer(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER)) cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary() @@ -54,15 +56,7 @@ class BaseModel: 
__validator__: SchemaValidator def __init_subclass__(cls) -> None: - cls.__validator__ = SchemaValidator( - core_schema.model_schema( - cls, - core_schema.typed_dict_schema( - {'x': core_schema.typed_dict_field(core_schema.definition_reference_schema('model'))} - ), - ref='model', - ) - ) + cls.__validator__ = SchemaValidator(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER)) cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary() diff --git a/tests/test_hypothesis.py b/tests/test_hypothesis.py index e925d56b7..ea0d67f0d 100644 --- a/tests/test_hypothesis.py +++ b/tests/test_hypothesis.py @@ -10,6 +10,7 @@ from typing_extensions import TypedDict from pydantic_core import SchemaSerializer, SchemaValidator, ValidationError +from pydantic_core import core_schema as cs @pytest.fixture(scope='module') @@ -58,21 +59,22 @@ def test_datetime_binary(datetime_schema, data): @pytest.fixture(scope='module') def definition_schema(): return SchemaValidator( - { - 'type': 'typed-dict', - 'ref': 'Branch', - 'fields': { - 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'sub_branch': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}}, - 'default': None, + cs.definitions_schema( + cs.definition_reference_schema('Branch'), + [ + cs.typed_dict_schema( + { + 'name': cs.typed_dict_field(cs.str_schema()), + 'sub_branch': cs.typed_dict_field( + cs.with_default_schema( + cs.nullable_schema(cs.definition_reference_schema('Branch')), default=None + ) + ), }, - }, - }, - } + ref='Branch', + ) + ], + ) ) diff --git a/tests/test_json.py b/tests/test_json.py index 272fb5dcb..9bba05c14 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -254,18 +254,23 @@ def __init__(self, my_foo: int, my_inners: List['Foobar']): # force a recursive model to ensure we exercise the transfer of definitions from the loaded # serializer - c = core_schema.model_schema( - 
Foobar, - core_schema.typed_dict_schema( - { - 'my_foo': core_schema.typed_dict_field(core_schema.int_schema(), serialization_alias='myFoo'), - 'my_inners': core_schema.typed_dict_field( - core_schema.list_schema(core_schema.definition_reference_schema('foobar')), - serialization_alias='myInners', + c = core_schema.definitions_schema( + core_schema.definition_reference_schema(schema_ref='foobar'), + [ + core_schema.model_schema( + Foobar, + core_schema.typed_dict_schema( + { + 'my_foo': core_schema.typed_dict_field(core_schema.int_schema(), serialization_alias='myFoo'), + 'my_inners': core_schema.typed_dict_field( + core_schema.list_schema(core_schema.definition_reference_schema('foobar')), + serialization_alias='myInners', + ), + } ), - } - ), - ref='foobar', + ref='foobar', + ) + ], ) v = SchemaValidator(c) s = SchemaSerializer(c) diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index 512a1977d..c119ca567 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -255,8 +255,12 @@ def args(*args, **kwargs): (core_schema.is_subclass_schema, args(MyModel), {'type': 'is-subclass', 'cls': MyModel}), ( core_schema.definitions_schema, - args({'type': 'int'}, [{'type': 'int'}]), - {'type': 'definitions', 'schema': {'type': 'int'}, 'definitions': [{'type': 'int'}]}, + args({'type': 'definition-ref', 'schema_ref': 'an-int'}, [{'type': 'int', 'ref': 'an-int'}]), + { + 'type': 'definitions', + 'schema': {'type': 'definition-ref', 'schema_ref': 'an-int'}, + 'definitions': [{'type': 'int', 'ref': 'an-int'}], + }, ), (core_schema.definition_reference_schema, args('foo'), {'type': 'definition-ref', 'schema_ref': 'foo'}), ( diff --git a/tests/test_typing.py b/tests/test_typing.py index ea51029a4..55c3732e8 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -118,22 +118,28 @@ def test_schema_typing() -> None: schema: CoreSchema = {'type': 'function-plain', 'function': {'type': 'general', 'function': validator}} 
SchemaValidator(schema) schema: CoreSchema = { - 'ref': 'Branch', - 'type': 'typed-dict', - 'fields': { - 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'sub_branch': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'definition-ref', 'schema_ref': 'Branch'}], + 'type': 'definitions', + 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}, + 'definitions': [ + { + 'type': 'typed-dict', + 'fields': { + 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, + 'sub_branch': { + 'type': 'typed-dict-field', + 'schema': { + 'type': 'default', + 'schema': { + 'type': 'nullable', + 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}, + }, + 'default': None, + }, }, - 'default': None, }, - }, - }, + 'ref': 'Branch', + } + ], } SchemaValidator(schema) schema: CoreSchema = {'type': 'date', 'le': date.today()} diff --git a/tests/validators/test_definitions.py b/tests/validators/test_definitions.py index bff170890..b6ac5d133 100644 --- a/tests/validators/test_definitions.py +++ b/tests/validators/test_definitions.py @@ -78,10 +78,12 @@ def test_repeated_ref(): SchemaValidator( core_schema.tuple_positional_schema( [ - core_schema.int_schema(ref='foobar'), - # the definition has to be used for it to go into slots/reusable and therefore trigger the error - core_schema.definition_reference_schema('foobar'), - core_schema.int_schema(ref='foobar'), + core_schema.definitions_schema( + core_schema.definition_reference_schema('foobar'), [core_schema.int_schema(ref='foobar')] + ), + core_schema.definitions_schema( + core_schema.definition_reference_schema('foobar'), [core_schema.int_schema(ref='foobar')] + ), ] ) ) @@ -90,14 +92,16 @@ def test_repeated_ref(): def test_repeat_after(): with pytest.raises(SchemaError, match='SchemaError: Duplicate ref: `foobar`'): SchemaValidator( - core_schema.tuple_positional_schema( - [ - 
core_schema.definitions_schema( - core_schema.list_schema(core_schema.definition_reference_schema('foobar')), - [core_schema.int_schema(ref='foobar')], - ), - core_schema.int_schema(ref='foobar'), - ] + core_schema.definitions_schema( + core_schema.tuple_positional_schema( + [ + core_schema.definitions_schema( + core_schema.definition_reference_schema('foobar'), [core_schema.int_schema(ref='foobar')] + ), + core_schema.definition_reference_schema('foobar'), + ] + ), + [core_schema.int_schema(ref='foobar')], ) ) diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index 06ee6c460..f23999cfa 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -13,39 +13,6 @@ def test_branch_nullable(): - v = SchemaValidator( - { - 'type': 'typed-dict', - 'ref': 'Branch', - 'fields': { - 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'sub_branch': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}}, - 'default': None, - }, - }, - }, - } - ) - - assert v.validate_python({'name': 'root'}) == {'name': 'root', 'sub_branch': None} - assert plain_repr(v).startswith( - 'SchemaValidator(title="typed-dict",validator=DefinitionRef(DefinitionRefValidator{' - ) - assert ',definitions=[TypedDict(TypedDictValidator{' in plain_repr(v) - - assert v.validate_python({'name': 'root', 'sub_branch': {'name': 'b1'}}) == ( - {'name': 'root', 'sub_branch': {'name': 'b1', 'sub_branch': None}} - ) - assert v.validate_python({'name': 'root', 'sub_branch': {'name': 'b1', 'sub_branch': {'name': 'b2'}}}) == ( - {'name': 'root', 'sub_branch': {'name': 'b1', 'sub_branch': {'name': 'b2', 'sub_branch': None}}} - ) - - -def test_branch_nullable_definitions(): v = SchemaValidator( core_schema.definitions_schema( {'type': 'definition-ref', 'schema_ref': 'Branch'}, @@ -99,24 
+66,25 @@ def test_unused_ref(): def test_nullable_error(): v = SchemaValidator( - { - 'ref': 'Branch', - 'type': 'typed-dict', - 'fields': { - 'width': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}, - 'sub_branch': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'definition-ref', 'schema_ref': 'Branch'}], - }, - 'default': None, + core_schema.definitions_schema( + core_schema.definition_reference_schema('Branch'), + [ + core_schema.typed_dict_schema( + { + 'width': core_schema.typed_dict_field(core_schema.int_schema()), + 'sub_branch': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.union_schema( + [core_schema.none_schema(), core_schema.definition_reference_schema('Branch')] + ), + default=None, + ) + ), }, - }, - }, - } + ref='Branch', + ) + ], + ) ) assert v.validate_python({'width': 123, 'sub_branch': {'width': 321}}) == ( {'width': 123, 'sub_branch': {'width': 321, 'sub_branch': None}} @@ -141,24 +109,23 @@ def test_nullable_error(): def test_list(): v = SchemaValidator( - { - 'type': 'typed-dict', - 'ref': 'BranchList', - 'fields': { - 'width': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}, - 'branches': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': { - 'type': 'list', - 'items_schema': {'type': 'definition-ref', 'schema_ref': 'BranchList'}, - }, - 'default': None, + core_schema.definitions_schema( + core_schema.definition_reference_schema('BranchList'), + [ + core_schema.typed_dict_schema( + { + 'width': core_schema.typed_dict_field(core_schema.int_schema()), + 'branches': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.list_schema(core_schema.definition_reference_schema('BranchList')), + default=None, + ) + ), }, - }, - }, - } + ref='BranchList', + ) + ], + ) ) assert v.validate_python({'width': 1, 'branches': [{'width': 2}, {'width': 3, 'branches': 
[{'width': 4}]}]}) == ( { @@ -183,45 +150,37 @@ class Bar: """ v = SchemaValidator( - { - 'ref': 'Foo', - 'type': 'typed-dict', - 'fields': { - 'height': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}, - 'bar': { - 'type': 'typed-dict-field', - 'schema': { - 'ref': 'Bar', - 'type': 'typed-dict', - 'fields': { - 'width': {'type': 'typed-dict-field', 'schema': {'type': 'int'}}, - 'bars': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': { - 'type': 'list', - 'items_schema': {'type': 'definition-ref', 'schema_ref': 'Bar'}, - }, - 'default': None, - }, - }, - 'foo': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'definition-ref', 'schema_ref': 'Foo'}], - }, - 'default': None, - }, - }, - }, + core_schema.definitions_schema( + core_schema.definition_reference_schema('Foo'), + [ + core_schema.typed_dict_schema( + { + 'height': core_schema.typed_dict_field(core_schema.int_schema()), + 'bar': core_schema.typed_dict_field(core_schema.definition_reference_schema('Bar')), }, - }, - }, - } + ref='Foo', + ), + core_schema.typed_dict_schema( + { + 'width': core_schema.typed_dict_field(core_schema.int_schema()), + 'bars': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.list_schema(core_schema.definition_reference_schema('Bar')), default=None + ) + ), + 'foo': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.union_schema( + [core_schema.none_schema(), core_schema.definition_reference_schema('Foo')] + ), + default=None, + ) + ), + }, + ref='Bar', + ), + ], + ) ) v.validate_python( { @@ -244,28 +203,28 @@ class Branch: branch: Optional['Branch'] v = SchemaValidator( - { - 'type': 'model', - 'ref': 'Branch', - 'cls': Branch, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'width': {'type': 'model-field', 'schema': {'type': 'int'}}, - 'branch': { - 'type': 'model-field', - 'schema': { - 
'type': 'default', - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'definition-ref', 'schema_ref': 'Branch'}], - }, - 'default': None, - }, - }, - }, - }, - } + core_schema.definitions_schema( + core_schema.definition_reference_schema('Branch'), + [ + core_schema.model_schema( + Branch, + core_schema.model_fields_schema( + { + 'width': core_schema.model_field(core_schema.int_schema()), + 'branch': core_schema.model_field( + core_schema.with_default_schema( + core_schema.union_schema( + [core_schema.none_schema(), core_schema.definition_reference_schema('Branch')] + ), + default=None, + ) + ), + } + ), + ref='Branch', + ) + ], + ) ) m1: Branch = v.validate_python({'width': '1'}) assert isinstance(m1, Branch) @@ -311,20 +270,19 @@ def test_invalid_schema(): def test_outside_parent(): v = SchemaValidator( - { - 'type': 'typed-dict', - 'fields': { - 'tuple1': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'tuple-positional', - 'items_schema': [{'type': 'int'}, {'type': 'int'}, {'type': 'str'}], - 'ref': 'tuple-iis', - }, - }, - 'tuple2': {'type': 'typed-dict-field', 'schema': {'type': 'definition-ref', 'schema_ref': 'tuple-iis'}}, - }, - } + core_schema.definitions_schema( + core_schema.typed_dict_schema( + { + 'tuple1': core_schema.typed_dict_field(core_schema.definition_reference_schema('tuple-iis')), + 'tuple2': core_schema.typed_dict_field(core_schema.definition_reference_schema('tuple-iis')), + } + ), + [ + core_schema.tuple_positional_schema( + [core_schema.int_schema(), core_schema.int_schema(), core_schema.str_schema()], ref='tuple-iis' + ) + ], + ) ) assert v.validate_python({'tuple1': [1, '1', 'frog'], 'tuple2': [2, '2', 'toad']}) == { @@ -335,21 +293,23 @@ def test_outside_parent(): def test_recursion_branch(): v = SchemaValidator( - { - 'type': 'typed-dict', - 'ref': 'Branch', - 'fields': { - 'name': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'branch': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 
'default', - 'schema': {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}}, - 'default': None, + core_schema.definitions_schema( + core_schema.definition_reference_schema('Branch'), + [ + core_schema.typed_dict_schema( + { + 'name': core_schema.typed_dict_field(core_schema.str_schema()), + 'branch': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.nullable_schema(core_schema.definition_reference_schema('Branch')), + default=None, + ) + ), }, - }, - }, - }, + ref='Branch', + ) + ], + ), {'from_attributes': True}, ) assert ',definitions=[TypedDict(TypedDictValidator{' in plain_repr(v) @@ -377,21 +337,23 @@ def test_recursion_branch(): def test_recursion_branch_from_attributes(): v = SchemaValidator( - { - 'type': 'model-fields', - 'ref': 'Branch', - 'fields': { - 'name': {'type': 'model-field', 'schema': {'type': 'str'}}, - 'branch': { - 'type': 'model-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 'Branch'}}, - 'default': None, + core_schema.definitions_schema( + core_schema.definition_reference_schema('Branch'), + [ + core_schema.model_fields_schema( + { + 'name': core_schema.model_field(core_schema.str_schema()), + 'branch': core_schema.model_field( + core_schema.with_default_schema( + core_schema.nullable_schema(core_schema.definition_reference_schema('Branch')), + default=None, + ) + ), }, - }, - }, - }, + ref='Branch', + ) + ], + ), {'from_attributes': True}, ) @@ -424,7 +386,10 @@ def test_recursion_branch_from_attributes(): def test_definition_list(): v = SchemaValidator( - {'type': 'list', 'ref': 'the-list', 'items_schema': {'type': 'definition-ref', 'schema_ref': 'the-list'}} + core_schema.definitions_schema( + core_schema.definition_reference_schema('the-list'), + [core_schema.list_schema(core_schema.definition_reference_schema('the-list'), ref='the-list')], + ) ) assert ',definitions=[List(ListValidator{' in 
plain_repr(v) assert v.validate_python([]) == [] @@ -448,30 +413,27 @@ def test_definition_list(): @pytest.fixture(scope='module') def multiple_tuple_schema() -> SchemaValidator: return SchemaValidator( - { - 'type': 'typed-dict', - 'fields': { - 'f1': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'tuple-positional', - 'items_schema': [ - {'type': 'int'}, - {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 't'}}, - ], - 'ref': 't', - }, - }, - 'f2': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 't'}}, - 'default': None, - }, - }, - }, - } + core_schema.definitions_schema( + core_schema.typed_dict_schema( + { + 'f1': core_schema.typed_dict_field(core_schema.definition_reference_schema('t')), + 'f2': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.nullable_schema(core_schema.definition_reference_schema('t')), default=None + ) + ), + } + ), + [ + core_schema.tuple_positional_schema( + [ + core_schema.int_schema(), + core_schema.nullable_schema(core_schema.definition_reference_schema('t')), + ], + ref='t', + ) + ], + ) ) @@ -554,18 +516,21 @@ def wrap_func(input_value, validator, info): return validator(input_value) + (42,) v = SchemaValidator( - { - 'type': 'function-wrap', - 'ref': 'wrapper', - 'function': {'type': 'general', 'function': wrap_func}, - 'schema': { - 'type': 'tuple-positional', - 'items_schema': [ - {'type': 'int'}, - {'type': 'nullable', 'schema': {'type': 'definition-ref', 'schema_ref': 'wrapper'}}, - ], - }, - } + core_schema.definitions_schema( + core_schema.definition_reference_schema('wrapper'), + [ + core_schema.general_wrap_validator_function( + wrap_func, + core_schema.tuple_positional_schema( + [ + core_schema.int_schema(), + core_schema.nullable_schema(core_schema.definition_reference_schema('wrapper')), + ] + ), + ref='wrapper', + ) + ], + ) ) assert v.validate_python((1, 
None)) == (1, None, 42) assert v.validate_python((1, (2, (3, None)))) == (1, (2, (3, None, 42), 42), 42) @@ -585,19 +550,19 @@ def wrap_func(input_value, validator, info): def test_union_ref_strictness(): v = SchemaValidator( - { - 'fields': { - 'a': {'type': 'typed-dict-field', 'schema': {'type': 'int', 'ref': 'int-type'}}, - 'b': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'union', - 'choices': [{'type': 'definition-ref', 'schema_ref': 'int-type'}, {'type': 'str'}], - }, - }, - }, - 'type': 'typed-dict', - } + core_schema.definitions_schema( + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field(core_schema.definition_reference_schema('int-type')), + 'b': core_schema.typed_dict_field( + core_schema.union_schema( + [core_schema.definition_reference_schema('int-type'), core_schema.str_schema()] + ) + ), + } + ), + [core_schema.int_schema(ref='int-type')], + ) ) assert v.validate_python({'a': 1, 'b': '2'}) == {'a': 1, 'b': '2'} assert v.validate_python({'a': 1, 'b': 2}) == {'a': 1, 'b': 2} @@ -613,16 +578,19 @@ def test_union_ref_strictness(): def test_union_container_strictness(): v = SchemaValidator( - { - 'fields': { - 'b': { - 'type': 'typed-dict-field', - 'schema': {'type': 'union', 'choices': [{'type': 'int', 'ref': 'int-type'}, {'type': 'str'}]}, - }, - 'a': {'type': 'typed-dict-field', 'schema': {'type': 'definition-ref', 'schema_ref': 'int-type'}}, - }, - 'type': 'typed-dict', - } + core_schema.definitions_schema( + core_schema.typed_dict_schema( + { + 'b': core_schema.typed_dict_field( + core_schema.union_schema( + [core_schema.definition_reference_schema('int-type'), core_schema.str_schema()] + ) + ), + 'a': core_schema.typed_dict_field(core_schema.definition_reference_schema('int-type')), + } + ), + [core_schema.int_schema(ref='int-type')], + ) ) assert v.validate_python({'a': 1, 'b': '2'}) == {'a': 1, 'b': '2'} assert v.validate_python({'a': 1, 'b': 2}) == {'a': 1, 'b': 2} @@ -639,26 +607,25 @@ def 
test_union_container_strictness(): @pytest.mark.parametrize('strict', [True, False], ids=lambda s: f'strict={s}') def test_union_cycle(strict: bool): s = SchemaValidator( - { - 'choices': [ - { - 'fields': { - 'foobar': { - 'type': 'typed-dict-field', - 'schema': { - 'items_schema': {'schema_ref': 'root-schema', 'type': 'definition-ref'}, - 'type': 'list', - }, - } - }, - 'type': 'typed-dict', - } + core_schema.definitions_schema( + core_schema.definition_reference_schema('root-schema'), + [ + core_schema.union_schema( + [ + core_schema.typed_dict_schema( + { + 'foobar': core_schema.typed_dict_field( + core_schema.list_schema(core_schema.definition_reference_schema('root-schema')) + ) + } + ) + ], + auto_collapse=False, + strict=strict, + ref='root-schema', + ) ], - 'auto_collapse': False, - 'strict': strict, - 'ref': 'root-schema', - 'type': 'union', - } + ) ) data = {'foobar': []} @@ -681,18 +648,20 @@ def f(input_value, info): return input_value + ' Changed' v = SchemaValidator( - { - 'choices': [ - { - 'type': 'function-after', - 'function': {'type': 'general', 'function': f}, - 'schema': {'schema_ref': 'root-schema', 'type': 'definition-ref'}, - }, - {'type': 'int'}, + core_schema.definitions_schema( + core_schema.definition_reference_schema('root-schema'), + [ + core_schema.union_schema( + [ + core_schema.general_after_validator_function( + f, core_schema.definition_reference_schema('root-schema') + ), + core_schema.int_schema(), + ], + ref='root-schema', + ) ], - 'ref': 'root-schema', - 'type': 'union', - } + ) ) assert v.validate_python(123) == 123 @@ -727,19 +696,21 @@ def f(input_value, info): return f'f-{int(count) + 1}' v = SchemaValidator( - { - 'choices': [ - { - 'type': 'function-before', - 'function': {'type': 'general', 'function': f}, - 'schema': {'schema_ref': 'root-schema', 'type': 'definition-ref'}, - } + core_schema.definitions_schema( + core_schema.definition_reference_schema('root-schema'), + [ + core_schema.union_schema( + [ + 
core_schema.general_before_validator_function( + f, core_schema.definition_reference_schema('root-schema') + ) + ], + auto_collapse=False, + strict=strict, + ref='root-schema', + ) ], - 'auto_collapse': False, - 'strict': strict, - 'ref': 'root-schema', - 'type': 'union', - } + ) ) with pytest.raises(ValidationError) as exc_info: @@ -758,23 +729,21 @@ def f(input_value, info): def test_many_uses_of_ref(): # check we can safely exceed RECURSION_GUARD_LIMIT without upsetting the recursion guard v = SchemaValidator( - { - 'type': 'typed-dict', - 'ref': 'Branch', - 'fields': { - 'name': { - 'type': 'typed-dict-field', - 'schema': {'type': 'str', 'max_length': 8, 'ref': 'limited-string'}, - }, - 'other_names': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'list', - 'items_schema': {'type': 'definition-ref', 'schema_ref': 'limited-string'}, + core_schema.definitions_schema( + core_schema.definition_reference_schema('Branch'), + [ + core_schema.typed_dict_schema( + { + 'name': core_schema.typed_dict_field(core_schema.definition_reference_schema('limited-string')), + 'other_names': core_schema.typed_dict_field( + core_schema.list_schema(core_schema.definition_reference_schema('limited-string')) + ), }, - }, - }, - } + ref='Branch', + ), + core_schema.str_schema(max_length=8, ref='limited-string'), + ], + ) ) assert v.validate_python({'name': 'Anne', 'other_names': ['Bob', 'Charlie']}) == { @@ -812,7 +781,8 @@ def test_error_inside_definition_wrapper(): } ) assert str(exc_info.value) == ( - 'Field "sub_branch":\n' + 'Error building "typed-dict" validator:\n' + ' SchemaError: Field "sub_branch":\n' ' SchemaError: Error building "default" validator:\n' " SchemaError: 'default' and 'default_factory' cannot be used together" ) @@ -820,7 +790,7 @@ def test_error_inside_definition_wrapper(): def test_recursive_definitions_schema(pydantic_version) -> None: s = core_schema.definitions_schema( - core_schema.definition_reference_schema(schema_ref='a'), + 
core_schema.definition_reference_schema('a'), [ core_schema.typed_dict_schema( { @@ -861,7 +831,7 @@ def test_recursive_definitions_schema(pydantic_version) -> None: def test_unsorted_definitions_schema() -> None: s = core_schema.definitions_schema( - core_schema.definition_reference_schema(schema_ref='td'), + core_schema.definition_reference_schema('td'), [ core_schema.typed_dict_schema( {'x': core_schema.typed_dict_field(core_schema.definition_reference_schema('int'))}, ref='td' @@ -883,22 +853,28 @@ def test_validate_assignment(pydantic_version) -> None: class Model: x: List['Model'] - schema = core_schema.dataclass_schema( - Model, - core_schema.dataclass_args_schema( - 'Model', - [ - core_schema.dataclass_field( - name='x', - schema=core_schema.list_schema(core_schema.definition_reference_schema('model')), - kw_only=False, - ) - ], - ), - ['x'], - ref='model', - config=core_schema.CoreConfig(revalidate_instances='always'), + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('model'), + [ + core_schema.dataclass_schema( + Model, + core_schema.dataclass_args_schema( + 'Model', + [ + core_schema.dataclass_field( + name='x', + schema=core_schema.list_schema(core_schema.definition_reference_schema('model')), + kw_only=False, + ) + ], + ), + ['x'], + ref='model', + config=core_schema.CoreConfig(revalidate_instances='always'), + ) + ], ) + v = SchemaValidator(schema) data = [Model(x=[Model(x=[])])] diff --git a/tests/validators/test_tagged_union.py b/tests/validators/test_tagged_union.py index e2fac95af..3061b2c34 100644 --- a/tests/validators/test_tagged_union.py +++ b/tests/validators/test_tagged_union.py @@ -3,7 +3,7 @@ import pytest from dirty_equals import IsAnyStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import CoreConfig, SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson from .test_typed_dict import Cls @@ -495,23 +495,20 @@ def test_from_attributes(): def 
test_use_ref(): v = SchemaValidator( - { - 'type': 'tagged-union', - 'discriminator': 'foobar', - 'choices': { - 'apple': { - 'type': 'typed-dict', - 'ref': 'apple', - 'fields': {'a': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}}, - }, - 'apple2': {'type': 'definition-ref', 'schema_ref': 'apple'}, - 'banana': { - 'type': 'typed-dict', - 'fields': {'b': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}}, + core_schema.definitions_schema( + core_schema.tagged_union_schema( + discriminator='foobar', + choices={ + 'apple': core_schema.definition_reference_schema('apple'), + 'apple2': core_schema.definition_reference_schema('apple'), + 'banana': core_schema.typed_dict_schema( + {'b': core_schema.typed_dict_field(core_schema.str_schema())} + ), }, - }, - }, - {'from_attributes': True}, + ), + [core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(core_schema.str_schema())}, ref='apple')], + ), + config=CoreConfig(from_attributes=True), ) assert v.validate_python({'foobar': 'apple', 'a': 'apple'}) == {'a': 'apple'} assert v.validate_python({'foobar': 'apple2', 'a': 'apple'}) == {'a': 'apple'} diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index d4ceb4b9f..7e0f608f1 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -396,14 +396,19 @@ def test_no_strict_check(): def test_strict_reference(): v = SchemaValidator( - core_schema.tuple_positional_schema( + core_schema.definitions_schema( + core_schema.definition_reference_schema(schema_ref='tuple-ref'), [ - core_schema.float_schema(), - core_schema.union_schema( - [core_schema.int_schema(), core_schema.definition_reference_schema('tuple-ref')] - ), + core_schema.tuple_positional_schema( + [ + core_schema.float_schema(), + core_schema.union_schema( + [core_schema.int_schema(), core_schema.definition_reference_schema('tuple-ref')] + ), + ], + ref='tuple-ref', + ) ], - ref='tuple-ref', ) ) assert 'strict_required:true' in plain_repr(v) From 
fb7b50133a8bbf1bf5fcd5dabd94b62f2b58ff6a Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 19 Sep 2023 09:37:26 +0100 Subject: [PATCH 039/550] fix function-after validator changing validation mode to Python (#967) --- src/validators/function.rs | 12 ++++++------ tests/validators/test_function.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/validators/function.rs b/src/validators/function.rs index 206e10a0c..9f8fe75a7 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -100,7 +100,7 @@ macro_rules! impl_validator { state: &mut ValidationState<'_>, ) -> ValResult<'data, PyObject> { let validate = |v, s: &mut ValidationState<'_>| self.validator.validate(py, v, s); - self._validate(validate, py, input.to_object(py).into_ref(py), state) + self._validate(validate, py, input, state) } fn validate_assignment<'data>( &self, @@ -158,7 +158,7 @@ impl FunctionBeforeValidator { &'s self, call: impl FnOnce(&'data PyAny, &mut ValidationState<'_>) -> ValResult<'data, PyObject>, py: Python<'data>, - input: &'data PyAny, + input: &'data impl Input<'data>, state: &'s mut ValidationState<'_>, ) -> ValResult<'data, PyObject> { let r = if self.info_arg { @@ -187,11 +187,11 @@ pub struct FunctionAfterValidator { impl_build!(FunctionAfterValidator, "function-after"); impl FunctionAfterValidator { - fn _validate<'s, 'data>( + fn _validate<'s, 'data, I: Input<'data>>( &'s self, - call: impl FnOnce(&'data PyAny, &mut ValidationState<'_>) -> ValResult<'data, PyObject>, + call: impl FnOnce(&'data I, &mut ValidationState<'_>) -> ValResult<'data, PyObject>, py: Python<'data>, - input: &'data PyAny, + input: &'data I, state: &mut ValidationState<'_>, ) -> ValResult<'data, PyObject> { let v = call(input, state)?; @@ -326,7 +326,7 @@ impl FunctionWrapValidator { &'s self, handler: &'s PyAny, py: Python<'data>, - input: &'data PyAny, + input: &'data impl Input<'data>, state: 
&mut ValidationState, ) -> ValResult<'data, PyObject> { let r = if self.info_arg { diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index a4f9737f7..07710be41 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -1,3 +1,4 @@ +import datetime import platform import re from copy import deepcopy @@ -971,3 +972,21 @@ def __repr__(self) -> str: 'ValidationInfo(config=None, context=None)', "FieldValidationInfo(config=None, context=None, field_name='x')", ] + + +def test_function_after_doesnt_change_mode() -> None: + # https://github.com/pydantic/pydantic/issues/7468 - function-after was + # incorrectly forcing Python validation mode + + def identity(v): + return v + + schema = core_schema.no_info_after_validator_function(identity, core_schema.date_schema(strict=True)) + v = SchemaValidator(schema) + + # this input should be valid JSON input, but isn't valid Python input, so + # the following tests will pass if the after_validator is not + # forcing the mode to Python + assert v.validate_json(b'"2000-01-01"') == datetime.date(2000, 1, 1) + with pytest.raises(ValidationError): + v.validate_python(b'"2000-01-01"') From 367a67aaa17c2e0a4daf086250c265f2a61ee499 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 19 Sep 2023 16:42:09 +0100 Subject: [PATCH 040/550] implementing `validate_strings` (#883) Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com> --- python/pydantic_core/_pydantic_core.pyi | 29 ++- src/build_tools.rs | 5 +- src/errors/mod.rs | 2 +- src/errors/types.rs | 41 ++--- src/errors/validation_exception.rs | 47 +++-- src/errors/value_exception.rs | 5 +- src/input/input_abstract.rs | 37 +++- src/input/input_json.rs | 127 ++++--------- src/input/input_python.rs | 15 +- src/input/input_string.rs | 229 ++++++++++++++++++++++++ src/input/mod.rs | 8 +- src/input/return_enums.rs | 35 ++++ src/input/shared.rs | 8 +- src/lookup_key.rs | 39 +++- src/validators/arguments.rs | 
1 + src/validators/dataclass.rs | 52 ++++-- src/validators/dict.rs | 24 ++- src/validators/function.rs | 2 +- src/validators/generator.rs | 16 +- src/validators/json_or_python.rs | 4 +- src/validators/mod.rs | 52 ++++-- src/validators/model_fields.rs | 61 ++++--- src/validators/typed_dict.rs | 67 ++++--- src/validators/union.rs | 3 +- src/validators/validation_state.rs | 1 + tests/conftest.py | 3 + tests/test_validate_strings.py | 121 +++++++++++++ tests/validators/test_bool.py | 3 +- tests/validators/test_float.py | 3 +- tests/validators/test_int.py | 3 +- tests/validators/test_json.py | 13 +- 31 files changed, 777 insertions(+), 279 deletions(-) create mode 100644 src/input/input_string.rs create mode 100644 tests/test_validate_strings.py diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 9b24e7076..be2b64793 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -52,6 +52,8 @@ _recursion_limit: int _T = TypeVar('_T', default=Any, covariant=True) +_StringInput: TypeAlias = 'dict[str, _StringInput]' + @final class Some(Generic[_T]): """ @@ -168,6 +170,29 @@ class SchemaValidator: ValidationError: If validation fails or if the JSON data is invalid. Exception: Other error types maybe raised if internal errors occur. + Returns: + The validated Python object. + """ + def validate_strings( + self, input: _StringInput, *, strict: bool | None = None, context: 'dict[str, Any] | None' = None + ) -> Any: + """ + Validate a string against the schema and return the validated Python object. + + This is similar to `validate_json` but applies to scenarios where the input will be a string but not + JSON data, e.g. URL fragments, query parameters, etc. + + Arguments: + input: The input as a string, or bytes/bytearray if `strict=False`. + strict: Whether to validate the object in strict mode. + If `None`, the value of [`CoreConfig.strict`][pydantic_core.core_schema.CoreConfig] is used. 
+ context: The context to use for validation, this is passed to functional validators as + [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. + + Raises: + ValidationError: If validation fails or if the JSON data is invalid. + Exception: Other error types maybe raised if internal errors occur. + Returns: The validated Python object. """ @@ -680,7 +705,7 @@ class ValidationError(ValueError): def from_exception_data( title: str, line_errors: list[InitErrorDetails], - error_mode: Literal['python', 'json'] = 'python', + input_type: Literal['python', 'json'] = 'python', hide_input: bool = False, ) -> ValidationError: """ @@ -693,7 +718,7 @@ class ValidationError(ValueError): title: The title of the error, as used in the heading of `str(validation_error)` line_errors: A list of [`InitErrorDetails`][pydantic_core.InitErrorDetails] which contain information about errors that occurred during validation. - error_mode: Whether the error is for a Python object or JSON. + input_type: Whether the error is for a Python object or JSON. hide_input: Whether to hide the input value in the error message. 
""" @property diff --git a/src/build_tools.rs b/src/build_tools.rs index d2bc7c2cd..47fa569ac 100644 --- a/src/build_tools.rs +++ b/src/build_tools.rs @@ -6,7 +6,8 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; use pyo3::{intern, FromPyObject, PyErrArguments}; -use crate::errors::{ErrorMode, ValError}; +use crate::errors::ValError; +use crate::input::InputType; use crate::tools::SchemaDict; use crate::ValidationError; @@ -86,7 +87,7 @@ impl SchemaError { ValError::LineErrors(raw_errors) => { let line_errors = raw_errors.into_iter().map(|e| e.into_py(py)).collect(); let validation_error = - ValidationError::new(line_errors, "Schema".to_object(py), ErrorMode::Python, false); + ValidationError::new(line_errors, "Schema".to_object(py), InputType::Python, false); let schema_error = SchemaError(SchemaErrorEnum::ValidationError(validation_error)); match Py::new(py, schema_error) { Ok(err) => PyErr::from_value(err.into_ref(py)), diff --git a/src/errors/mod.rs b/src/errors/mod.rs index ed10049de..6a253197f 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -8,7 +8,7 @@ mod value_exception; pub use self::line_error::{InputValue, ValError, ValLineError, ValResult}; pub use self::location::LocItem; -pub use self::types::{list_all_errors, ErrorMode, ErrorType, ErrorTypeDefaults, Number}; +pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number}; pub use self::validation_exception::ValidationError; pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault}; diff --git a/src/errors/types.rs b/src/errors/types.rs index d7e1051b7..da4d5fdd7 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -2,37 +2,20 @@ use std::any::type_name; use std::borrow::Cow; use std::fmt; -use ahash::AHashMap; -use num_bigint::BigInt; -use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError}; +use pyo3::exceptions::{PyKeyError, PyTypeError}; use pyo3::once_cell::GILOnceCell; use 
pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; -use crate::input::Int; -use crate::tools::{extract_i64, py_err, py_error_type}; +use ahash::AHashMap; +use num_bigint::BigInt; use strum::{Display, EnumMessage, IntoEnumIterator}; use strum_macros::EnumIter; -use super::PydanticCustomError; - -#[derive(Clone, Debug)] -pub enum ErrorMode { - Python, - Json, -} - -impl TryFrom<&str> for ErrorMode { - type Error = PyErr; +use crate::input::{InputType, Int}; +use crate::tools::{extract_i64, py_err, py_error_type}; - fn try_from(error_mode: &str) -> PyResult { - match error_mode { - "python" => Ok(Self::Python), - "json" => Ok(Self::Json), - s => py_err!(PyValueError; "Invalid error mode: {}", s), - } - } -} +use super::PydanticCustomError; #[pyfunction] pub fn list_all_errors(py: Python) -> PyResult<&PyList> { @@ -45,12 +28,12 @@ pub fn list_all_errors(py: Python) -> PyResult<&PyList> { d.set_item("message_template_python", message_template_python)?; d.set_item( "example_message_python", - error_type.render_message(py, &ErrorMode::Python)?, + error_type.render_message(py, InputType::Python)?, )?; let message_template_json = error_type.message_template_json(); if message_template_python != message_template_json { d.set_item("message_template_json", message_template_json)?; - d.set_item("example_message_json", error_type.render_message(py, &ErrorMode::Json)?)?; + d.set_item("example_message_json", error_type.render_message(py, InputType::Json)?)?; } d.set_item("example_context", error_type.py_dict(py)?)?; errors.push(d); @@ -623,10 +606,10 @@ impl ErrorType { } } - pub fn render_message(&self, py: Python, error_mode: &ErrorMode) -> PyResult { - let tmpl = match error_mode { - ErrorMode::Python => self.message_template_python(), - ErrorMode::Json => self.message_template_json(), + pub fn render_message(&self, py: Python, input_type: InputType) -> PyResult { + let tmpl = match input_type { + InputType::Python => self.message_template_python(), + _ => 
self.message_template_json(), }; match self { Self::NoSuchAttribute { attribute, .. } => render!(tmpl, attribute), diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index e6563f597..d0977001a 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -16,12 +16,13 @@ use serde_json::ser::PrettyFormatter; use crate::build_tools::py_schema_error_type; use crate::errors::LocItem; use crate::get_pydantic_version; +use crate::input::InputType; use crate::serializers::{SerMode, SerializationState}; use crate::tools::{safe_repr, SchemaDict}; use super::line_error::ValLineError; use super::location::Location; -use super::types::{ErrorMode, ErrorType}; +use super::types::ErrorType; use super::value_exception::PydanticCustomError; use super::{InputValue, ValError}; @@ -31,16 +32,16 @@ use super::{InputValue, ValError}; pub struct ValidationError { line_errors: Vec, title: PyObject, - error_mode: ErrorMode, + input_type: InputType, hide_input: bool, } impl ValidationError { - pub fn new(line_errors: Vec, title: PyObject, error_mode: ErrorMode, hide_input: bool) -> Self { + pub fn new(line_errors: Vec, title: PyObject, input_type: InputType, hide_input: bool) -> Self { Self { line_errors, title, - error_mode, + input_type, hide_input, } } @@ -48,7 +49,7 @@ impl ValidationError { pub fn from_val_error( py: Python, title: PyObject, - error_mode: ErrorMode, + input_type: InputType, error: ValError, outer_location: Option, hide_input: bool, @@ -63,9 +64,7 @@ impl ValidationError { .collect(), None => raw_errors.into_iter().map(|e| e.into_py(py)).collect(), }; - - let validation_error = Self::new(line_errors, title, error_mode, hide_input); - + let validation_error = Self::new(line_errors, title, input_type, hide_input); match Py::new(py, validation_error) { Ok(err) => { if validation_error_cause { @@ -87,7 +86,7 @@ impl ValidationError { pub fn display(&self, py: Python, prefix_override: Option<&'static str>, 
hide_input: bool) -> String { let url_prefix = get_url_prefix(py, include_url_env(py)); - let line_errors = pretty_py_line_errors(py, &self.error_mode, self.line_errors.iter(), url_prefix, hide_input); + let line_errors = pretty_py_line_errors(py, self.input_type, self.line_errors.iter(), url_prefix, hide_input); if let Some(prefix) = prefix_override { format!("{prefix}\n{line_errors}") } else { @@ -238,12 +237,12 @@ impl ValidationError { #[pymethods] impl ValidationError { #[staticmethod] - #[pyo3(signature = (title, line_errors, error_mode="python", hide_input=false))] + #[pyo3(signature = (title, line_errors, input_type="python", hide_input=false))] fn from_exception_data( py: Python, title: PyObject, line_errors: &PyList, - error_mode: &str, + input_type: &str, hide_input: bool, ) -> PyResult> { Py::new( @@ -251,7 +250,7 @@ impl ValidationError { Self { line_errors: line_errors.iter().map(PyLineError::try_from).collect::>()?, title, - error_mode: ErrorMode::try_from(error_mode)?, + input_type: InputType::try_from(input_type)?, hide_input, }, ) @@ -279,7 +278,7 @@ impl ValidationError { if iteration_error.is_some() { return py.None(); } - e.as_dict(py, url_prefix, include_context, &self.error_mode) + e.as_dict(py, url_prefix, include_context, self.input_type) .unwrap_or_else(|err| { iteration_error = Some(err); py.None() @@ -309,7 +308,7 @@ impl ValidationError { url_prefix: get_url_prefix(py, include_url), include_context, extra: &extra, - error_mode: &self.error_mode, + input_type: &self.input_type, }; let writer: Vec = Vec::with_capacity(self.line_errors.len() * 200); @@ -387,13 +386,13 @@ macro_rules! 
truncate_input_value { pub fn pretty_py_line_errors<'a>( py: Python, - error_mode: &ErrorMode, + input_type: InputType, line_errors_iter: impl Iterator, url_prefix: Option<&str>, hide_input: bool, ) -> String { line_errors_iter - .map(|i| i.pretty(py, error_mode, url_prefix, hide_input)) + .map(|i| i.pretty(py, input_type, url_prefix, hide_input)) .collect::, _>>() .unwrap_or_else(|err| vec![format!("[error formatting line errors: {err}]")]) .join("\n") @@ -477,12 +476,12 @@ impl PyLineError { py: Python, url_prefix: Option<&str>, include_context: bool, - error_mode: &ErrorMode, + input_type: InputType, ) -> PyResult { let dict = PyDict::new(py); dict.set_item("type", self.error_type.type_string())?; dict.set_item("loc", self.location.to_object(py))?; - dict.set_item("msg", self.error_type.render_message(py, error_mode)?)?; + dict.set_item("msg", self.error_type.render_message(py, input_type)?)?; dict.set_item("input", &self.input_value)?; if include_context { if let Some(context) = self.error_type.py_dict(py)? 
{ @@ -505,14 +504,14 @@ impl PyLineError { fn pretty( &self, py: Python, - error_mode: &ErrorMode, + input_type: InputType, url_prefix: Option<&str>, hide_input: bool, ) -> Result { let mut output = String::with_capacity(200); write!(output, "{}", self.location)?; - let message = match self.error_type.render_message(py, error_mode) { + let message = match self.error_type.render_message(py, input_type) { Ok(message) => message, Err(err) => format!("(error rendering message: {err})"), }; @@ -565,7 +564,7 @@ struct ValidationErrorSerializer<'py> { url_prefix: Option<&'py str>, include_context: bool, extra: &'py crate::serializers::Extra<'py>, - error_mode: &'py ErrorMode, + input_type: &'py InputType, } impl<'py> Serialize for ValidationErrorSerializer<'py> { @@ -581,7 +580,7 @@ impl<'py> Serialize for ValidationErrorSerializer<'py> { url_prefix: self.url_prefix, include_context: self.include_context, extra: self.extra, - error_mode: self.error_mode, + input_type: self.input_type, }; seq.serialize_element(&line_s)?; } @@ -595,7 +594,7 @@ struct PyLineErrorSerializer<'py> { url_prefix: Option<&'py str>, include_context: bool, extra: &'py crate::serializers::Extra<'py>, - error_mode: &'py ErrorMode, + input_type: &'py InputType, } impl<'py> Serialize for PyLineErrorSerializer<'py> { @@ -620,7 +619,7 @@ impl<'py> Serialize for PyLineErrorSerializer<'py> { let msg = self .line_error .error_type - .render_message(py, self.error_mode) + .render_message(py, *self.input_type) .map_err(py_err_json::)?; map.serialize_entry("msg", &msg)?; diff --git a/src/errors/value_exception.rs b/src/errors/value_exception.rs index d0c08bf5f..f7d877b30 100644 --- a/src/errors/value_exception.rs +++ b/src/errors/value_exception.rs @@ -1,9 +1,8 @@ -use crate::errors::ErrorMode; use pyo3::exceptions::{PyException, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; -use crate::input::Input; +use crate::input::{Input, InputType}; use crate::tools::extract_i64; use 
super::{ErrorType, ValError}; @@ -164,7 +163,7 @@ impl PydanticKnownError { } pub fn message(&self, py: Python) -> PyResult { - self.error_type.render_message(py, &ErrorMode::Python) + self.error_type.render_message(py, InputType::Python) } fn __str__(&self, py: Python) -> PyResult { diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index d799da473..f4a760a45 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -1,9 +1,11 @@ use std::fmt; +use pyo3::exceptions::PyValueError; use pyo3::types::{PyDict, PyType}; use pyo3::{intern, prelude::*}; use crate::errors::{InputValue, LocItem, ValResult}; +use crate::tools::py_err; use crate::{PyMultiHostUrl, PyUrl}; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; @@ -14,6 +16,7 @@ use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, Gen pub enum InputType { Python, Json, + String, } impl IntoPy for InputType { @@ -21,6 +24,20 @@ impl IntoPy for InputType { match self { Self::Json => intern!(py, "json").into(), Self::Python => intern!(py, "python").into(), + Self::String => intern!(py, "string").into(), + } + } +} + +impl TryFrom<&str> for InputType { + type Error = PyErr; + + fn try_from(error_mode: &str) -> PyResult { + match error_mode { + "python" => Ok(Self::Python), + "json" => Ok(Self::Json), + "string" => Ok(Self::String), + s => py_err!(PyValueError; "Invalid error mode: {}", s), } } } @@ -38,7 +55,9 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { None } - fn is_none(&self) -> bool; + fn is_none(&self) -> bool { + false + } fn input_is_instance(&self, _class: &PyType) -> Option<&PyAny> { None @@ -320,3 +339,19 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { self.strict_timedelta(microseconds_overflow_behavior) } } + +/// The problem to solve here is that iterating a `StringMapping` returns an owned +/// `StringMapping`, but all the other iterators return references. 
By introducing +/// this trait we abstract over whether the return value from the iterator is owned +/// or borrowed; all we care about is that we can borrow it again with `borrow_input` +/// for some lifetime 'a. +/// +/// This lifetime `'a` is shorter than the original lifetime `'data` of the input, +/// which is only a problem in error branches. To resolve we have to call `into_owned` +/// to extend out the lifetime to match the original input. +pub trait BorrowInput { + type Input<'a>: Input<'a> + where + Self: 'a; + fn borrow_input(&self) -> &Self::Input<'_>; +} diff --git a/src/input/input_json.rs b/src/input/input_json.rs index d948e2493..07f3554e6 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -12,11 +12,10 @@ use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; -use super::parse_json::JsonArray; -use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; +use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int, string_to_vec}; use super::{ - EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, - GenericIterator, GenericMapping, Input, JsonArgs, JsonInput, + BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, + GenericIterator, GenericMapping, Input, JsonArgs, JsonArray, JsonInput, }; impl<'a> Input<'a> for JsonInput { @@ -355,6 +354,15 @@ impl<'a> Input<'a> for JsonInput { } } +impl BorrowInput for &'_ JsonInput { + type Input<'a> = JsonInput where Self: 'a; + fn borrow_input(&self) -> &Self::Input<'_> { + self + } +} + +/// TODO: it would be good to get JsonInput and StringMapping string variants to go through this +/// implementation /// Required for Dict 
keys so the string can behave like an Input impl<'a> Input<'a> for String { fn as_loc_item(&self) -> LocItem { @@ -365,11 +373,6 @@ impl<'a> Input<'a> for String { InputValue::String(self) } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn is_none(&self) -> bool { - false - } - fn as_kwargs(&'a self, _py: Python<'a>) -> Option<&'a PyDict> { None } @@ -395,47 +398,29 @@ impl<'a> Input<'a> for String { serde_json::from_str(self.as_str()).map_err(|e| map_json_err(self, e)) } - fn validate_str(&'a self, _strict: bool) -> ValResult> { - Ok(self.as_str().into()) - } fn strict_str(&'a self) -> ValResult> { - self.validate_str(false) + Ok(self.as_str().into()) } - fn validate_bytes(&'a self, _strict: bool) -> ValResult> { - Ok(self.as_bytes().into()) - } - #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_bytes(&'a self) -> ValResult> { - self.validate_bytes(false) + Ok(self.as_bytes().into()) } fn strict_bool(&self) -> ValResult { - Err(ValError::new(ErrorTypeDefaults::BoolType, self)) - } - fn lax_bool(&self) -> ValResult { str_as_bool(self, self) } fn strict_int(&'a self) -> ValResult> { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } - fn lax_int(&'a self) -> ValResult> { match self.parse() { Ok(i) => Ok(EitherInt::I64(i)), Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), } } - #[cfg_attr(has_coverage_attribute, coverage(off))] fn ultra_strict_float(&'a self) -> ValResult> { self.strict_float() } - #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_float(&'a self) -> ValResult> { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) - } - fn lax_float(&'a self) -> ValResult> { str_as_float(self, self) } @@ -443,49 +428,29 @@ impl<'a> Input<'a> for String { create_decimal(self.to_object(py).into_ref(py), self, py) } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn validate_dict(&'a self, _strict: bool) -> ValResult> { - Err(ValError::new(ErrorTypeDefaults::DictType, self)) - } 
#[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_dict(&'a self) -> ValResult> { - self.validate_dict(false) + Err(ValError::new(ErrorTypeDefaults::DictType, self)) } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn validate_list(&'a self, _strict: bool) -> ValResult> { - Err(ValError::new(ErrorTypeDefaults::ListType, self)) - } #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_list(&'a self) -> ValResult> { - self.validate_list(false) + Err(ValError::new(ErrorTypeDefaults::ListType, self)) } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn validate_tuple(&'a self, _strict: bool) -> ValResult> { - Err(ValError::new(ErrorTypeDefaults::TupleType, self)) - } #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_tuple(&'a self) -> ValResult> { - self.validate_tuple(false) + Err(ValError::new(ErrorTypeDefaults::TupleType, self)) } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn validate_set(&'a self, _strict: bool) -> ValResult> { - Err(ValError::new(ErrorTypeDefaults::SetType, self)) - } #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_set(&'a self) -> ValResult> { - self.validate_set(false) + Err(ValError::new(ErrorTypeDefaults::SetType, self)) } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn validate_frozenset(&'a self, _strict: bool) -> ValResult> { - Err(ValError::new(ErrorTypeDefaults::FrozenSetType, self)) - } #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_frozenset(&'a self) -> ValResult> { - self.validate_frozenset(false) + Err(ValError::new(ErrorTypeDefaults::FrozenSetType, self)) } fn extract_generic_iterable(&'a self) -> ValResult> { @@ -496,60 +461,42 @@ impl<'a> Input<'a> for String { Ok(string_to_vec(self).into()) } - fn validate_date(&self, _strict: bool) -> ValResult { - bytes_as_date(self, self.as_bytes()) - } - #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_date(&self) -> ValResult { - self.validate_date(false) + bytes_as_date(self, 
self.as_bytes()) } - fn validate_time( - &self, - _strict: bool, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - bytes_as_time(self, self.as_bytes(), microseconds_overflow_behavior) - } - #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_time( &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, + microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { - self.validate_time(false, microseconds_overflow_behavior) + bytes_as_time(self, self.as_bytes(), microseconds_overflow_behavior) } - fn validate_datetime( - &self, - _strict: bool, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior) - } - #[cfg_attr(has_coverage_attribute, coverage(off))] fn strict_datetime( &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, + microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { - self.validate_datetime(false, microseconds_overflow_behavior) + bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior) } - fn validate_timedelta( + fn strict_timedelta( &self, - _strict: bool, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, + microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior) } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn strict_timedelta( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - self.validate_timedelta(false, microseconds_overflow_behavior) +} + +impl BorrowInput for &'_ String { + type Input<'a> = String where Self: 'a; + fn borrow_input(&self) -> &Self::Input<'_> { + self } } -fn string_to_vec(s: &str) -> 
JsonArray { - JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect()) +impl BorrowInput for String { + type Input<'a> = String where Self: 'a; + fn borrow_input(&self) -> &Self::Input<'_> { + self + } } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 3fbf20240..9e5fd1d1f 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -23,7 +23,7 @@ use super::datetime::{ }; use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; use super::{ - py_string_str, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, + py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs, }; @@ -184,7 +184,7 @@ impl<'a> Input<'a> for PyAny { if let Ok(py_bytes) = self.downcast::() { serde_json::from_slice(py_bytes.as_bytes()).map_err(|e| map_json_err(self, e)) } else if let Ok(py_str) = self.downcast::() { - let str = py_str.to_str()?; + let str = py_string_str(py_str)?; serde_json::from_str(str).map_err(|e| map_json_err(self, e)) } else if let Ok(py_byte_array) = self.downcast::() { // Safety: from_slice does not run arbitrary Python code and the GIL is held so the @@ -196,7 +196,7 @@ impl<'a> Input<'a> for PyAny { } fn strict_str(&'a self) -> ValResult> { - if let Ok(py_str) = ::try_from_exact(self) { + if let Ok(py_str) = PyString::try_from_exact(self) { Ok(py_str.into()) } else if let Ok(py_str) = self.downcast::() { // force to a rust string to make sure behavior is consistent whether or not we go via a @@ -208,7 +208,7 @@ impl<'a> Input<'a> for PyAny { } fn exact_str(&'a self) -> ValResult> { - if let Ok(py_str) = ::try_from_exact(self) { + if let Ok(py_str) = PyString::try_from_exact(self) { Ok(EitherString::Py(py_str)) } else { Err(ValError::new(ErrorTypeDefaults::IntType, self)) @@ -710,6 +710,13 
@@ impl<'a> Input<'a> for PyAny { } } +impl BorrowInput for &'_ PyAny { + type Input<'a> = PyAny where Self: 'a; + fn borrow_input(&self) -> &Self::Input<'_> { + self + } +} + /// Best effort check of whether it's likely to make sense to inspect obj for attributes and iterate over it /// with `obj.dir()` fn from_attributes_applicable(obj: &PyAny) -> bool { diff --git a/src/input/input_string.rs b/src/input/input_string.rs new file mode 100644 index 000000000..72a32d897 --- /dev/null +++ b/src/input/input_string.rs @@ -0,0 +1,229 @@ +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyString}; + +use speedate::MicrosecondsPrecisionOverflowBehavior; + +use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::input::py_string_str; +use crate::tools::safe_repr; +use crate::validators::decimal::create_decimal; + +use super::datetime::{ + bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime, +}; +use super::shared::{map_json_err, str_as_bool, str_as_float}; +use super::{ + BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, + GenericIterator, GenericMapping, Input, JsonInput, +}; + +#[derive(Debug)] +pub enum StringMapping<'py> { + String(&'py PyString), + Mapping(&'py PyDict), +} + +impl<'py> ToPyObject for StringMapping<'py> { + fn to_object(&self, py: Python<'_>) -> PyObject { + match self { + Self::String(s) => s.to_object(py), + Self::Mapping(d) => d.to_object(py), + } + } +} + +impl<'py> StringMapping<'py> { + pub fn new_key(py_key: &'py PyAny) -> ValResult<'py, StringMapping> { + if let Ok(py_str) = py_key.downcast::() { + Ok(Self::String(py_str)) + } else { + Err(ValError::new(ErrorTypeDefaults::StringType, py_key)) + } + } + + pub fn new_value(py_value: &'py PyAny) -> ValResult<'py, Self> { + if let Ok(py_str) = py_value.downcast::() { + Ok(Self::String(py_str)) + } else if let Ok(value) = 
py_value.downcast::() { + Ok(Self::Mapping(value)) + } else { + Err(ValError::new(ErrorTypeDefaults::StringType, py_value)) + } + } +} + +impl<'a> Input<'a> for StringMapping<'a> { + fn as_loc_item(&self) -> LocItem { + match self { + Self::String(s) => s.to_string_lossy().as_ref().into(), + Self::Mapping(d) => safe_repr(d).to_string().into(), + } + } + + fn as_error_value(&'a self) -> InputValue<'a> { + match self { + Self::String(s) => s.as_error_value(), + Self::Mapping(d) => InputValue::PyAny(d), + } + } + + fn as_kwargs(&'a self, _py: Python<'a>) -> Option<&'a PyDict> { + None + } + + fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { + // do we want to support this? + Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) + } + + fn validate_dataclass_args(&'a self, _dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>> { + match self { + StringMapping::String(_) => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)), + StringMapping::Mapping(m) => Ok(GenericArguments::StringMapping(m)), + } + } + + fn parse_json(&'a self) -> ValResult<'a, JsonInput> { + match self { + Self::String(s) => { + let str = py_string_str(s)?; + serde_json::from_str(str).map_err(|e| map_json_err(self, e)) + } + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), + } + } + + fn strict_str(&'a self) -> ValResult> { + match self { + Self::String(s) => Ok((*s).into()), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::StringType, self)), + } + } + + fn strict_bytes(&'a self) -> ValResult> { + match self { + Self::String(s) => py_string_str(s).map(|b| b.as_bytes().into()), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), + } + } + + fn lax_bytes(&'a self) -> ValResult> { + match self { + Self::String(s) => { + let str = py_string_str(s)?; + Ok(str.as_bytes().into()) + } + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), + } + } + + fn strict_bool(&self) -> ValResult 
{ + match self { + Self::String(s) => str_as_bool(self, py_string_str(s)?), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), + } + } + + fn strict_int(&'a self) -> ValResult> { + match self { + Self::String(s) => match py_string_str(s)?.parse() { + Ok(i) => Ok(EitherInt::I64(i)), + Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), + }, + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::IntType, self)), + } + } + + fn ultra_strict_float(&'a self) -> ValResult> { + self.strict_float() + } + + fn strict_float(&'a self) -> ValResult> { + match self { + Self::String(s) => str_as_float(self, py_string_str(s)?), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), + } + } + + fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { + match self { + Self::String(s) => create_decimal(s, self, py), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)), + } + } + + fn strict_dict(&'a self) -> ValResult> { + match self { + Self::String(_) => Err(ValError::new(ErrorTypeDefaults::DictType, self)), + Self::Mapping(d) => Ok(GenericMapping::StringMapping(d)), + } + } + + fn strict_list(&'a self) -> ValResult> { + Err(ValError::new(ErrorTypeDefaults::ListType, self)) + } + + fn strict_tuple(&'a self) -> ValResult> { + Err(ValError::new(ErrorTypeDefaults::TupleType, self)) + } + + fn strict_set(&'a self) -> ValResult> { + Err(ValError::new(ErrorTypeDefaults::SetType, self)) + } + + fn strict_frozenset(&'a self) -> ValResult> { + Err(ValError::new(ErrorTypeDefaults::FrozenSetType, self)) + } + + fn extract_generic_iterable(&'a self) -> ValResult> { + Err(ValError::new(ErrorTypeDefaults::IterableType, self)) + } + + fn validate_iter(&self) -> ValResult { + Err(ValError::new(ErrorTypeDefaults::IterableType, self)) + } + + fn strict_date(&self) -> ValResult { + match self { + Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes()), + Self::Mapping(_) => 
Err(ValError::new(ErrorTypeDefaults::DateType, self)), + } + } + + fn strict_time( + &self, + microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult { + match self { + Self::String(s) => bytes_as_time(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeType, self)), + } + } + + fn strict_datetime( + &self, + microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult { + match self { + Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), + } + } + + fn strict_timedelta( + &self, + microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult { + match self { + Self::String(s) => bytes_as_timedelta(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), + } + } +} + +impl BorrowInput for StringMapping<'_> { + type Input<'a> = StringMapping<'a> where Self: 'a; + fn borrow_input(&self) -> &Self::Input<'_> { + self + } +} diff --git a/src/input/mod.rs b/src/input/mod.rs index c15f54b2a..22d774a8c 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -6,6 +6,7 @@ mod datetime; mod input_abstract; mod input_json; mod input_python; +mod input_string; mod parse_json; mod return_enums; mod shared; @@ -15,12 +16,13 @@ pub(crate) use datetime::{ duration_as_pytimedelta, pydate_as_date, pydatetime_as_datetime, pytime_as_time, EitherDate, EitherDateTime, EitherTime, EitherTimedelta, }; -pub(crate) use input_abstract::{Input, InputType}; -pub(crate) use parse_json::{JsonInput, JsonObject}; +pub(crate) use input_abstract::{BorrowInput, Input, InputType}; +pub(crate) use input_string::StringMapping; +pub(crate) use parse_json::{JsonArray, JsonInput, JsonObject}; pub(crate) use 
return_enums::{ py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, - MappingGenericIterator, PyArgs, + MappingGenericIterator, PyArgs, StringMappingGenericIterator, }; // Defined here as it's not exported by pyo3 diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index e97fe8b81..56f86f580 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -25,6 +25,7 @@ use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, Val use crate::tools::py_err; use crate::validators::{CombinedValidator, ValidationState, Validator}; +use super::input_string::StringMapping; use super::parse_json::{JsonArray, JsonInput, JsonObject}; use super::{py_error_on_minusone, Input}; @@ -429,6 +430,7 @@ impl<'a> GenericIterable<'a> { pub enum GenericMapping<'a> { PyDict(&'a PyDict), PyMapping(&'a PyMapping), + StringMapping(&'a PyDict), PyGetAttr(&'a PyAny, Option<&'a PyDict>), JsonObject(&'a JsonObject), } @@ -506,6 +508,38 @@ impl<'py> Iterator for MappingGenericIterator<'py> { } } +pub struct StringMappingGenericIterator<'py> { + dict_iter: PyDictIterator<'py>, +} + +impl<'py> StringMappingGenericIterator<'py> { + pub fn new(dict: &'py PyDict) -> ValResult<'py, Self> { + Ok(Self { dict_iter: dict.iter() }) + } +} + +impl<'py> Iterator for StringMappingGenericIterator<'py> { + // key (the first member of the tuple could be a simple String) + type Item = ValResult<'py, (StringMapping<'py>, StringMapping<'py>)>; + + fn next(&mut self) -> Option { + match self.dict_iter.next() { + Some((py_key, py_value)) => { + let key = match StringMapping::new_key(py_key) { + Ok(key) => key, + Err(e) => return Some(Err(e)), + }; + let value = match StringMapping::new_value(py_value) { + Ok(value) => value, + Err(e) => return Some(Err(e)), + }; + Some(Ok((key, value))) + } + None => 
None, + } + } +} + pub struct AttributesGenericIterator<'py> { object: &'py PyAny, // PyO3 should export this type upstream @@ -691,6 +725,7 @@ impl<'a> JsonArgs<'a> { pub enum GenericArguments<'a> { Py(PyArgs<'a>), Json(JsonArgs<'a>), + StringMapping(&'a PyDict), } impl<'a> From> for GenericArguments<'a> { diff --git a/src/input/shared.rs b/src/input/shared.rs index d8733bd31..1a8e2b61c 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -2,9 +2,9 @@ use num_bigint::BigInt; use pyo3::{intern, PyAny, Python}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; -use crate::input::EitherInt; -use super::{EitherFloat, Input}; +use super::parse_json::{JsonArray, JsonInput}; +use super::{EitherFloat, EitherInt, Input}; pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> { ValError::new( @@ -150,3 +150,7 @@ pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a Py } Ok(EitherInt::Py(numerator)) } + +pub fn string_to_vec(s: &str) -> JsonArray { + JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect()) +} diff --git a/src/lookup_key.rs b/src/lookup_key.rs index 49915e3b4..36190c069 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -6,8 +6,8 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyMapping, PyString}; use crate::build_tools::py_schema_err; -use crate::errors::{ErrorType, ValLineError}; -use crate::input::{Input, JsonInput, JsonObject}; +use crate::errors::{py_err_string, ErrorType, ValError, ValLineError, ValResult}; +use crate::input::{Input, JsonInput, JsonObject, StringMapping}; use crate::tools::{extract_i64, py_err}; /// Used for getting items from python dicts, python objects, or JSON objects, in different ways @@ -109,7 +109,7 @@ impl LookupKey { pub fn py_get_dict_item<'data, 's>( &'s self, dict: &'data PyDict, - ) -> PyResult> { + ) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> { match self { Self::Simple { 
py_key, path, .. } => match dict.get_item(py_key) { Some(value) => Ok(Some((path, value))), @@ -143,10 +143,22 @@ impl LookupKey { } } + pub fn py_get_string_mapping_item<'data, 's>( + &'s self, + dict: &'data PyDict, + ) -> ValResult<'data, Option<(&'s LookupPath, StringMapping<'data>)>> { + if let Some((path, py_any)) = self.py_get_dict_item(dict)? { + let value = StringMapping::new_value(py_any)?; + Ok(Some((path, value))) + } else { + Ok(None) + } + } + pub fn py_get_mapping_item<'data, 's>( &'s self, dict: &'data PyMapping, - ) -> PyResult> { + ) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> { match self { Self::Simple { py_key, path, .. } => match dict.get_item(py_key) { Ok(value) => Ok(Some((path, value))), @@ -184,6 +196,23 @@ impl LookupKey { &'s self, obj: &'data PyAny, kwargs: Option<&'data PyDict>, + ) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> { + match self._py_get_attr(obj, kwargs) { + Ok(v) => Ok(v), + Err(err) => { + let error = py_err_string(obj.py(), err); + Err(ValError::new( + ErrorType::GetAttributeError { error, context: None }, + obj, + )) + } + } + } + + pub fn _py_get_attr<'data, 's>( + &'s self, + obj: &'data PyAny, + kwargs: Option<&'data PyDict>, ) -> PyResult> { if let Some(dict) = kwargs { if let Ok(Some(item)) = self.py_get_dict_item(dict) { @@ -235,7 +264,7 @@ impl LookupKey { pub fn json_get<'data, 's>( &'s self, dict: &'data JsonObject, - ) -> PyResult> { + ) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonInput)>> { match self { Self::Simple { key, path, .. 
} => match dict.get(key) { Some(value) => Ok(Some((path, value))), diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index e75ff310a..2c0fe4a0a 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -323,6 +323,7 @@ impl Validator for ArgumentsValidator { match args { GenericArguments::Py(a) => process!(a, py_get_dict_item, py_get, py_slice), GenericArguments::Json(a) => process!(a, json_get, json_get, json_slice), + GenericArguments::StringMapping(_) => unimplemented!(), } if !errors.is_empty() { Err(ValError::LineErrors(errors)) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 7b7be282e..117596b9f 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -8,7 +8,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; -use crate::input::{GenericArguments, Input}; +use crate::input::{BorrowInput, GenericArguments, Input}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; use crate::validators::function::convert_err; @@ -188,15 +188,21 @@ impl Validator for DataclassArgsValidator { kw_value = Some((lookup_path, value)); } } + let kw_value = kw_value + .as_ref() + .map(|(path, value)| (path, value.borrow_input())); match (pos_value, kw_value) { // found both positional and keyword arguments, error (Some(_), Some((_, kw_value))) => { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::MultipleArgumentValues, - kw_value, - field.name.clone(), - )); + errors.push( + ValLineError::new_with_loc( + ErrorTypeDefaults::MultipleArgumentValues, + kw_value, + field.name.clone(), + ) + .into_owned(py), + ); } // found a positional argument, validate it (Some(pos_value), None) => match field.validator.validate(py, pos_value, state) { @@ -216,10 +222,12 @@ impl Validator for DataclassArgsValidator { 
Ok(value) => set_item!(field, value), Err(ValError::LineErrors(line_errors)) => { errors.extend(line_errors.into_iter().map(|err| { - lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name) + lookup_path + .apply_error_loc(err, self.loc_by_alias, &field.name) + .into_owned(py) })); } - Err(err) => return Err(err), + Err(err) => return Err(err.into_owned(py)), } } // found neither, check if there is a default value, otherwise error @@ -267,11 +275,14 @@ impl Validator for DataclassArgsValidator { // Unknown / extra field match self.extra_behavior { ExtraBehavior::Forbid => { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::UnexpectedKeywordArgument, - value, - raw_key.as_loc_item(), - )); + errors.push( + ValLineError::new_with_loc( + ErrorTypeDefaults::UnexpectedKeywordArgument, + value, + raw_key.as_loc_item(), + ) + .into_owned(py), + ); } ExtraBehavior::Ignore => {} ExtraBehavior::Allow => { @@ -310,9 +321,24 @@ impl Validator for DataclassArgsValidator { } }}; } + match args { GenericArguments::Py(a) => process!(a, py_get_dict_item, py_get, py_slice), GenericArguments::Json(a) => process!(a, json_get, json_get, json_slice), + GenericArguments::StringMapping(a) => { + // StringMapping cannot pass positional args, so wrap the PyDict + // in a type with guaranteed empty args array for sake of the process + // macro + struct StringMappingArgs<'a> { + args: Option<&'a PyTuple>, + kwargs: Option<&'a PyDict>, + } + let a = StringMappingArgs { + args: None, + kwargs: Some(a), + }; + process!(a, py_get_string_mapping_item, py_get, py_slice); + } } Ok(()) }, diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 6a52b7f79..dc8f03937 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -4,7 +4,11 @@ use pyo3::types::PyDict; use crate::build_tools::is_strict; use crate::errors::{ValError, ValLineError, ValResult}; -use crate::input::{DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, 
MappingGenericIterator}; +use crate::input::BorrowInput; +use crate::input::{ + DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, + StringMappingGenericIterator, +}; use crate::tools::SchemaDict; @@ -78,6 +82,9 @@ impl Validator for DictValidator { GenericMapping::PyMapping(mapping) => { self.validate_generic_mapping(py, input, MappingGenericIterator::new(mapping)?, state) } + GenericMapping::StringMapping(dict) => { + self.validate_generic_mapping(py, input, StringMappingGenericIterator::new(dict)?, state) + } GenericMapping::PyGetAttr(_, _) => unreachable!(), GenericMapping::JsonObject(json_object) => { self.validate_generic_mapping(py, input, JsonObjectGenericIterator::new(json_object)?, state) @@ -113,9 +120,7 @@ impl DictValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - mapping_iter: impl Iterator< - Item = ValResult<'data, (&'data (impl Input<'data> + 'data), &'data (impl Input<'data> + 'data))>, - >, + mapping_iter: impl Iterator>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let output = PyDict::new(py); @@ -125,6 +130,8 @@ impl DictValidator { let value_validator = self.value_validator.as_ref(); for item_result in mapping_iter { let (key, value) = item_result?; + let key = key.borrow_input(); + let value = value.borrow_input(); let output_key = match key_validator.validate(py, key, state) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { @@ -132,24 +139,25 @@ impl DictValidator { // these are added in reverse order so [key] is shunted along by the second call errors.push( err.with_outer_location("[key]".into()) - .with_outer_location(key.as_loc_item()), + .with_outer_location(key.as_loc_item()) + .into_owned(py), ); } None } Err(ValError::Omit) => continue, - Err(err) => return Err(err), + Err(err) => return Err(err.into_owned(py)), }; let output_value = match value_validator.validate(py, value, state) { Ok(value) => Some(value), 
Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(key.as_loc_item())); + errors.push(err.with_outer_location(key.as_loc_item()).into_owned(py)); } None } Err(ValError::Omit) => continue, - Err(err) => return Err(err), + Err(err) => return Err(err.into_owned(py)), }; if let (Some(key), Some(value)) = (output_key, output_value) { output.set_item(key, value)?; diff --git a/src/validators/function.rs b/src/validators/function.rs index 9f8fe75a7..6334a5e16 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -544,7 +544,7 @@ impl ValidationInfo { context: extra.context.map(Into::into), field_name, data: extra.data.map(Into::into), - mode: extra.mode, + mode: extra.input_type, } } } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index c52910500..bf6d009e1 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -3,7 +3,7 @@ use std::fmt; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::errors::{ErrorMode, ErrorType, LocItem, ValError, ValResult}; +use crate::errors::{ErrorType, LocItem, ValError, ValResult}; use crate::input::{GenericIterator, Input}; use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; @@ -153,7 +153,7 @@ impl ValidatorIterator { return Err(ValidationError::from_val_error( py, "ValidatorIterator".to_object(py), - ErrorMode::Python, + InputType::Python, val_error, None, hide_input_in_errors, @@ -180,7 +180,7 @@ impl ValidatorIterator { return Err(ValidationError::from_val_error( py, "ValidatorIterator".to_object(py), - ErrorMode::Python, + InputType::Python, val_error, None, hide_input_in_errors, @@ -262,7 +262,7 @@ impl InternalValidator { context: extra.context.map(|d| d.into_py(py)), self_instance: extra.self_instance.map(|d| d.into_py(py)), recursion_guard: state.recursion_guard.clone(), - validation_mode: extra.mode, + validation_mode: extra.input_type, hide_input_in_errors, 
validation_error_cause, } @@ -277,7 +277,7 @@ impl InternalValidator { outer_location: Option, ) -> PyResult { let extra = Extra { - mode: self.validation_mode, + input_type: self.validation_mode, data: self.data.as_ref().map(|data| data.as_ref(py)), strict: self.strict, ultra_strict: false, @@ -292,7 +292,7 @@ impl InternalValidator { ValidationError::from_val_error( py, self.name.to_object(py), - ErrorMode::Python, + InputType::Python, e, outer_location, self.hide_input_in_errors, @@ -308,7 +308,7 @@ impl InternalValidator { outer_location: Option, ) -> PyResult { let extra = Extra { - mode: self.validation_mode, + input_type: self.validation_mode, data: self.data.as_ref().map(|data| data.as_ref(py)), strict: self.strict, ultra_strict: false, @@ -321,7 +321,7 @@ impl InternalValidator { ValidationError::from_val_error( py, self.name.to_object(py), - ErrorMode::Python, + InputType::Python, e, outer_location, self.hide_input_in_errors, diff --git a/src/validators/json_or_python.rs b/src/validators/json_or_python.rs index e62620027..828532fe5 100644 --- a/src/validators/json_or_python.rs +++ b/src/validators/json_or_python.rs @@ -57,9 +57,9 @@ impl Validator for JsonOrPython { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - match state.extra().mode { + match state.extra().input_type { InputType::Python => self.python.validate(py, input, state), - InputType::Json => self.json.validate(py, input, state), + _ => self.json.validate(py, input, state), } } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 9440f1527..de98eb583 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -10,8 +10,8 @@ use pyo3::{intern, PyTraverseError, PyVisit}; use crate::build_tools::{py_schema_err, py_schema_error_type, SchemaError}; use crate::definitions::DefinitionsBuilder; -use crate::errors::{ErrorMode, LocItem, ValError, ValResult, ValidationError}; -use crate::input::{Input, InputType}; +use 
crate::errors::{LocItem, ValError, ValResult, ValidationError}; +use crate::input::{Input, InputType, StringMapping}; use crate::py_gc::PyGcTraverse; use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; @@ -170,7 +170,7 @@ impl SchemaValidator { self_instance, &mut RecursionGuard::default(), ) - .map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Python)) + .map_err(|e| self.prepare_validation_err(py, e, InputType::Python)) } #[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None))] @@ -223,8 +223,26 @@ impl SchemaValidator { self_instance, recursion_guard, ) - .map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Json)), - Err(err) => Err(self.prepare_validation_err(py, err, ErrorMode::Json)), + .map_err(|e| self.prepare_validation_err(py, e, InputType::Json)), + Err(err) => Err(self.prepare_validation_err(py, err, InputType::Json)), + } + } + + #[pyo3(signature = (input, *, strict=None, context=None))] + pub fn validate_strings( + &self, + py: Python, + input: &PyAny, + strict: Option, + context: Option<&PyAny>, + ) -> PyResult { + let t = InputType::String; + let string_mapping = StringMapping::new_value(input).map_err(|e| self.prepare_validation_err(py, e, t))?; + + let recursion_guard = &mut RecursionGuard::default(); + match self._validate(py, &string_mapping, t, strict, None, context, None, recursion_guard) { + Ok(r) => Ok(r), + Err(e) => Err(self.prepare_validation_err(py, e, t)), } } @@ -241,7 +259,7 @@ impl SchemaValidator { context: Option<&PyAny>, ) -> PyResult { let extra = Extra { - mode: InputType::Python, + input_type: InputType::Python, data: None, strict, from_attributes, @@ -254,13 +272,13 @@ impl SchemaValidator { let mut state = ValidationState::new(extra, &self.definitions, guard); self.validator .validate_assignment(py, obj, field_name, field_value, &mut state) - .map_err(|e| self.prepare_validation_err(py, e, ErrorMode::Python)) + .map_err(|e| 
self.prepare_validation_err(py, e, InputType::Python)) } #[pyo3(signature = (*, strict=None, context=None))] pub fn get_default_value(&self, py: Python, strict: Option, context: Option<&PyAny>) -> PyResult { let extra = Extra { - mode: InputType::Python, + input_type: InputType::Python, data: None, strict, from_attributes: None, @@ -276,7 +294,7 @@ impl SchemaValidator { Some(v) => Ok(PySome::new(v).into_py(py)), None => Ok(py.None().into_py(py)), }, - Err(e) => Err(self.prepare_validation_err(py, e, ErrorMode::Python)), + Err(e) => Err(self.prepare_validation_err(py, e, InputType::Python)), } } @@ -305,7 +323,7 @@ impl SchemaValidator { &'data self, py: Python<'data>, input: &'data impl Input<'data>, - mode: InputType, + input_type: InputType, strict: Option, from_attributes: Option, context: Option<&'data PyAny>, @@ -316,18 +334,18 @@ impl SchemaValidator { 's: 'data, { let mut state = ValidationState::new( - Extra::new(strict, from_attributes, context, self_instance, mode), + Extra::new(strict, from_attributes, context, self_instance, input_type), &self.definitions, recursion_guard, ); self.validator.validate(py, input, &mut state) } - fn prepare_validation_err(&self, py: Python, error: ValError, error_mode: ErrorMode) -> PyErr { + fn prepare_validation_err(&self, py: Python, error: ValError, input_type: InputType) -> PyErr { ValidationError::from_val_error( py, self.title.clone_ref(py), - error_mode, + input_type, error, None, self.hide_input_in_errors, @@ -533,7 +551,7 @@ pub fn build_validator<'a>( #[derive(Debug)] pub struct Extra<'a> { /// Validation mode - pub mode: InputType, + pub input_type: InputType, /// This is used as the `data` kwargs to validator functions pub data: Option<&'a PyDict>, /// whether we're in strict or lax mode @@ -554,10 +572,10 @@ impl<'a> Extra<'a> { from_attributes: Option, context: Option<&'a PyAny>, self_instance: Option<&'a PyAny>, - mode: InputType, + input_type: InputType, ) -> Self { Extra { - mode, + input_type, data: 
None, strict, ultra_strict: false, @@ -571,7 +589,7 @@ impl<'a> Extra<'a> { impl<'a> Extra<'a> { pub fn as_strict(&self, ultra_strict: bool) -> Self { Self { - mode: self.mode, + input_type: self.input_type, data: self.data, strict: Some(true), ultra_strict, diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 29e9522f2..f2654c33e 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -7,10 +7,10 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ - AttributesGenericIterator, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, - MappingGenericIterator, + AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, + MappingGenericIterator, StringMappingGenericIterator, }; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; @@ -180,17 +180,13 @@ impl Validator for ModelFieldsValidator { for field in &self.fields { let op_key_value = match field.lookup_key.$get_method($dict $(, $kwargs )? 
) { Ok(v) => v, - Err(err) => { - errors.push(ValLineError::new_with_loc( - ErrorType::GetAttributeError { - error: py_err_string(py, err), - context: None, - }, - input, - field.name.clone(), - )); + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push(err.with_outer_location(field.name.as_loc_item())); + } continue; } + Err(err) => return ControlFlow::Break(err), }; if let Some((lookup_path, value)) = op_key_value { if let Some(ref mut used_keys) = used_keys { @@ -198,10 +194,7 @@ impl Validator for ModelFieldsValidator { // extra logic either way used_keys.insert(lookup_path.first_key()); } - match field - .validator - .validate(py, value, state) - { + match field.validator.validate(py, value.borrow_input(), state) { Ok(value) => { control_flow!(model_dict.set_item(&field.name_py, value))?; fields_set_vec.push(field.name_py.clone_ref(py)); @@ -209,10 +202,13 @@ impl Validator for ModelFieldsValidator { Err(ValError::Omit) => continue, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + errors.push( + lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name) + .into_owned(py) + ); } } - Err(err) => return ControlFlow::Break(err), + Err(err) => return ControlFlow::Break(err.into_owned(py)), } continue; } else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? 
{ @@ -242,25 +238,31 @@ impl Validator for ModelFieldsValidator { for err in line_errors { errors.push( err.with_outer_location(raw_key.as_loc_item()) - .with_type(ErrorTypeDefaults::InvalidKey), + .with_type(ErrorTypeDefaults::InvalidKey) + .into_owned(py) ); } continue; } - Err(err) => return Err(err), + Err(err) => return Err(err.into_owned(py)), }; - if used_keys.contains(either_str.as_cow()?.as_ref()) { + let cow = either_str.as_cow().map_err(|err| err.into_owned(py))?; + if used_keys.contains(cow.as_ref()) { continue; } + let value = value.borrow_input(); // Unknown / extra field match self.extra_behavior { ExtraBehavior::Forbid => { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::ExtraForbidden, - value, - raw_key.as_loc_item(), - )); + errors.push( + ValLineError::new_with_loc( + ErrorTypeDefaults::ExtraForbidden, + value, + raw_key.as_loc_item(), + ) + .into_owned(py) + ); } ExtraBehavior::Ignore => {} ExtraBehavior::Allow => { @@ -273,10 +275,10 @@ impl Validator for ModelFieldsValidator { } Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(raw_key.as_loc_item())); + errors.push(err.with_outer_location(raw_key.as_loc_item()).into_owned(py)); } } - Err(err) => return Err(err), + Err(err) => return Err(err.into_owned(py)), } } else { model_extra_dict.set_item(py_key, value.to_object(py))?; @@ -293,8 +295,9 @@ impl Validator for ModelFieldsValidator { } match dict { GenericMapping::PyDict(d) => process!(d, py_get_dict_item, DictGenericIterator), - GenericMapping::PyGetAttr(d, kwargs) => process!(d, py_get_attr, AttributesGenericIterator, kwargs), GenericMapping::PyMapping(d) => process!(d, py_get_mapping_item, MappingGenericIterator), + GenericMapping::StringMapping(d) => process!(d, py_get_string_mapping_item, StringMappingGenericIterator), + GenericMapping::PyGetAttr(d, kwargs) => process!(d, py_get_attr, AttributesGenericIterator, kwargs), GenericMapping::JsonObject(d) => process!(d, 
json_get, JsonObjectGenericIterator), } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index a095a52f1..56e4a8225 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -1,3 +1,5 @@ +use std::ops::ControlFlow; + use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; @@ -6,10 +8,10 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config, schema_or_config_same, ExtraBehavior}; -use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ - AttributesGenericIterator, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, - MappingGenericIterator, + AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, + MappingGenericIterator, StringMappingGenericIterator, }; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; @@ -18,8 +20,6 @@ use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; -use std::ops::ControlFlow; - #[derive(Debug, Clone)] struct TypedDictField { name: String, @@ -181,17 +181,13 @@ impl Validator for TypedDictValidator { for field in &self.fields { let op_key_value = match field.lookup_key.$get_method($dict $(, $kwargs )? 
) { Ok(v) => v, - Err(err) => { - errors.push(ValLineError::new_with_loc( - ErrorType::GetAttributeError { - error: py_err_string(py, err), - context: None, - }, - input, - field.name.clone(), - )); + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push(err.with_outer_location(field.name.as_loc_item())); + } continue; } + Err(err) => return ControlFlow::Break(err), }; if let Some((lookup_path, value)) = op_key_value { if let Some(ref mut used_keys) = used_keys { @@ -199,17 +195,21 @@ impl Validator for TypedDictValidator { // extra logic either way used_keys.insert(lookup_path.first_key()); } - match field.validator.validate(py, value, state) { + match field.validator.validate(py, value.borrow_input(), state) { Ok(value) => { control_flow!(output_dict.set_item(&field.name_py, value))?; } Err(ValError::Omit) => continue, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name)); + errors.push( + lookup_path + .apply_error_loc(err, self.loc_by_alias, &field.name) + .into_owned(py) + ); } } - Err(err) => return ControlFlow::Break(err), + Err(err) => return ControlFlow::Break(err.into_owned(py)), } continue; } else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? 
{ @@ -238,25 +238,31 @@ impl Validator for TypedDictValidator { for err in line_errors { errors.push( err.with_outer_location(raw_key.as_loc_item()) - .with_type(ErrorTypeDefaults::InvalidKey), + .with_type(ErrorTypeDefaults::InvalidKey) + .into_owned(py) ); } continue; } - Err(err) => return Err(err), + Err(err) => return Err(err.into_owned(py)), }; - if used_keys.contains(either_str.as_cow()?.as_ref()) { + let cow = either_str.as_cow().map_err(|err| err.into_owned(py))?; + if used_keys.contains(cow.as_ref()) { continue; } + let value = value.borrow_input(); // Unknown / extra field match self.extra_behavior { ExtraBehavior::Forbid => { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::ExtraForbidden, - value, - raw_key.as_loc_item(), - )); + errors.push( + ValLineError::new_with_loc( + ErrorTypeDefaults::ExtraForbidden, + value, + raw_key.as_loc_item(), + ) + .into_owned(py) + ); } ExtraBehavior::Ignore => {} ExtraBehavior::Allow => { @@ -268,10 +274,14 @@ impl Validator for TypedDictValidator { } Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(raw_key.as_loc_item())); + errors.push( + err + .with_outer_location(raw_key.as_loc_item()) + .into_owned(py) + ); } } - Err(err) => return Err(err), + Err(err) => return Err(err.into_owned(py)), } } else { output_dict.set_item(py_key, value.to_object(py))?; @@ -284,8 +294,9 @@ impl Validator for TypedDictValidator { } match dict { GenericMapping::PyDict(d) => process!(d, py_get_dict_item, DictGenericIterator), - GenericMapping::PyGetAttr(d, kwargs) => process!(d, py_get_attr, AttributesGenericIterator, kwargs), GenericMapping::PyMapping(d) => process!(d, py_get_mapping_item, MappingGenericIterator), + GenericMapping::StringMapping(d) => process!(d, py_get_string_mapping_item, StringMappingGenericIterator), + GenericMapping::PyGetAttr(d, kwargs) => process!(d, py_get_attr, AttributesGenericIterator, kwargs), GenericMapping::JsonObject(d) => process!(d, 
json_get, JsonObjectGenericIterator), } diff --git a/src/validators/union.rs b/src/validators/union.rs index 385a1f49d..4d3b0bd78 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -455,8 +455,9 @@ impl Validator for TaggedUnionValidator { let dict = input.validate_model_fields(self.strict, from_attributes)?; let tag = match dict { GenericMapping::PyDict(dict) => find_validator!(py_get_dict_item, dict), - GenericMapping::PyGetAttr(obj, kwargs) => find_validator!(py_get_attr, obj, kwargs), GenericMapping::PyMapping(mapping) => find_validator!(py_get_mapping_item, mapping), + GenericMapping::StringMapping(d) => find_validator!(py_get_dict_item, d), + GenericMapping::PyGetAttr(obj, kwargs) => find_validator!(py_get_attr, obj, kwargs), GenericMapping::JsonObject(mapping) => find_validator!(json_get, mapping), }?; self.find_call_validator(py, tag, input, state) diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index 75d6dae66..6cf5ce313 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -44,6 +44,7 @@ impl<'a> ValidationState<'a> { &'state mut self, f: impl FnOnce(&mut Extra<'a>), ) -> ValidationStateWithReboundExtra<'state, 'a> { + #[allow(clippy::unnecessary_struct_initialization)] let old_extra = Extra { ..self.extra }; f(&mut self.extra); ValidationStateWithReboundExtra { state: self, old_extra } diff --git a/tests/conftest.py b/tests/conftest.py index a5f5cc344..a83c49472 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,9 @@ def __init__( def validate_python(self, py_input, strict: bool | None = None, context: Any = None): return self.validator.validate_python(py_input, strict=strict, context=context) + def validate_json(self, json_str: str, strict: bool | None = None, context: Any = None): + return self.validator.validate_json(json_str, strict=strict, context=context) + def validate_test(self, py_input, strict: bool | None = None, context: Any = None): 
if self.validator_type == 'json': return self.validator.validate_json( diff --git a/tests/test_validate_strings.py b/tests/test_validate_strings.py new file mode 100644 index 000000000..0e0350d0b --- /dev/null +++ b/tests/test_validate_strings.py @@ -0,0 +1,121 @@ +import dataclasses +import re +from datetime import date, datetime + +import pytest + +from pydantic_core import SchemaValidator, ValidationError, core_schema + +from .conftest import Err + + +def test_bool(): + v = SchemaValidator(core_schema.bool_schema()) + + assert v.validate_strings('true') is True + assert v.validate_strings('true', strict=True) is True + assert v.validate_strings('false') is False + + +@pytest.mark.parametrize( + 'schema,input_value,expected,strict', + [ + (core_schema.int_schema(), '1', 1, False), + (core_schema.int_schema(), '1', 1, True), + (core_schema.int_schema(), 'xxx', Err('type=int_parsing'), True), + (core_schema.float_schema(), '1.1', 1.1, False), + (core_schema.float_schema(), '1.10', 1.1, False), + (core_schema.float_schema(), '1.1', 1.1, True), + (core_schema.float_schema(), '1.10', 1.1, True), + (core_schema.date_schema(), '2017-01-01', date(2017, 1, 1), False), + (core_schema.date_schema(), '2017-01-01', date(2017, 1, 1), True), + (core_schema.datetime_schema(), '2017-01-01T12:13:14.567', datetime(2017, 1, 1, 12, 13, 14, 567_000), False), + (core_schema.datetime_schema(), '2017-01-01T12:13:14.567', datetime(2017, 1, 1, 12, 13, 14, 567_000), True), + (core_schema.date_schema(), '2017-01-01T12:13:14.567', Err('type=date_from_datetime_inexact'), False), + (core_schema.date_schema(), '2017-01-01T12:13:14.567', Err('type=date_parsing'), True), + (core_schema.date_schema(), '2017-01-01T00:00:00', date(2017, 1, 1), False), + (core_schema.date_schema(), '2017-01-01T00:00:00', Err('type=date_parsing'), True), + ], + ids=repr, +) +def test_validate_strings(schema, input_value, expected, strict): + v = SchemaValidator(schema) + if isinstance(expected, Err): + with 
pytest.raises(ValidationError, match=re.escape(expected.message)): + v.validate_strings(input_value, strict=strict) + else: + assert v.validate_strings(input_value, strict=strict) == expected + + +def test_dict(): + v = SchemaValidator(core_schema.dict_schema(core_schema.int_schema(), core_schema.date_schema())) + + assert v.validate_strings({'1': '2017-01-01', '2': '2017-01-02'}) == {1: date(2017, 1, 1), 2: date(2017, 1, 2)} + assert v.validate_strings({'1': '2017-01-01', '2': '2017-01-02'}, strict=True) == { + 1: date(2017, 1, 1), + 2: date(2017, 1, 2), + } + + +def test_model(): + class MyModel: + # this is not required, but it avoids `__pydantic_fields_set__` being included in `__dict__` + __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' + field_a: int + field_b: date + + v = SchemaValidator( + core_schema.model_schema( + MyModel, + core_schema.model_fields_schema( + { + 'field_a': core_schema.model_field(core_schema.int_schema()), + 'field_b': core_schema.model_field(core_schema.date_schema()), + } + ), + ) + ) + m2 = v.validate_strings({'field_a': '1', 'field_b': '2017-01-01'}) + assert m2.__dict__ == {'field_a': 1, 'field_b': date(2017, 1, 1)} + m2 = v.validate_strings({'field_a': '1', 'field_b': '2017-01-01'}, strict=True) + assert m2.__dict__ == {'field_a': 1, 'field_b': date(2017, 1, 1)} + + +def test_dataclass(): + @dataclasses.dataclass + class MyDataClass: + field_a: int + field_b: date + + v = SchemaValidator( + core_schema.dataclass_schema( + MyDataClass, + core_schema.dataclass_args_schema( + 'MyDataClass', + [ + core_schema.dataclass_field('field_a', core_schema.int_schema()), + core_schema.dataclass_field('field_b', core_schema.date_schema()), + ], + ), + ['field_a', 'field_b'], + ) + ) + m2 = v.validate_strings({'field_a': '1', 'field_b': '2017-01-01'}) + assert m2.__dict__ == {'field_a': 1, 'field_b': date(2017, 1, 1)} + m2 = v.validate_strings({'field_a': '1', 'field_b': '2017-01-01'}, strict=True) 
+ assert m2.__dict__ == {'field_a': 1, 'field_b': date(2017, 1, 1)} + + +def test_typed_dict(): + v = SchemaValidator( + core_schema.typed_dict_schema( + { + 'field_a': core_schema.typed_dict_field(core_schema.int_schema()), + 'field_b': core_schema.typed_dict_field(core_schema.date_schema()), + } + ) + ) + m2 = v.validate_strings({'field_a': '1', 'field_b': '2017-01-01'}) + assert m2 == {'field_a': 1, 'field_b': date(2017, 1, 1)} + m2 = v.validate_strings({'field_a': '1', 'field_b': '2017-01-01'}, strict=True) + assert m2 == {'field_a': 1, 'field_b': date(2017, 1, 1)} diff --git a/tests/validators/test_bool.py b/tests/validators/test_bool.py index 714c07ffe..e71d41cb1 100644 --- a/tests/validators/test_bool.py +++ b/tests/validators/test_bool.py @@ -88,7 +88,8 @@ def test_bool_key(py_and_json: PyAndJson): assert v.validate_test({'true': 1, 'off': 2}) == {True: 1, False: 2} assert v.validate_test({'true': 1, 'off': 2}, strict=False) == {True: 1, False: 2} with pytest.raises(ValidationError, match='Input should be a valid boolean'): - v.validate_test({'true': 1, 'off': 2}, strict=True) + v.validate_python({'true': 1, 'off': 2}, strict=True) + assert v.validate_json('{"true": 1, "off": 2}', strict=True) == {True: 1, False: 2} def test_validate_assignment_not_supported() -> None: diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index 80b96eacd..74f0024ca 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -215,7 +215,8 @@ def test_float_key(py_and_json: PyAndJson): assert v.validate_test({'1': 1, '2': 2}) == {1: 1, 2: 2} assert v.validate_test({'1.5': 1, '2.4': 2}) == {1.5: 1, 2.4: 2} with pytest.raises(ValidationError, match='Input should be a valid number'): - v.validate_test({'1.5': 1, '2.5': 2}, strict=True) + v.validate_python({'1.5': 1, '2.5': 2}, strict=True) + assert v.validate_json('{"1.5": 1, "2.5": 2}', strict=True) == {1.5: 1, 2.5: 2} @pytest.mark.parametrize( diff --git 
a/tests/validators/test_int.py b/tests/validators/test_int.py index 43cd5bacb..8d5850dc8 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -402,7 +402,8 @@ def test_int_key(py_and_json: PyAndJson): v = py_and_json({'type': 'dict', 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'int'}}) assert v.validate_test({'1': 1, '2': 2}) == {1: 1, 2: 2} with pytest.raises(ValidationError, match='Input should be a valid integer'): - v.validate_test({'1': 1, '2': 2}, strict=True) + v.validate_python({'1': 1, '2': 2}, strict=True) + assert v.validate_json('{"1": 1, "2": 2}', strict=True) == {1: 1, 2: 2} def test_string_as_int_with_underscores() -> None: diff --git a/tests/validators/test_json.py b/tests/validators/test_json.py index 83ddca172..d8666d335 100644 --- a/tests/validators/test_json.py +++ b/tests/validators/test_json.py @@ -50,26 +50,33 @@ def test_any(py_and_json: PyAndJson, input_value, expected): [ ('{"a": 1}', {'a': 1}), (b'{"a": 1}', {'a': 1}), + ( + '🐈 Hello \ud800World', + Err( + 'Input should be a valid string, unable to parse raw data as a unicode string ' + "[type=string_unicode, input_value='🐈 Hello \\ud800World', input_type=str]" + ), + ), (bytearray(b'{"a": 1}'), {'a': 1}), ( 'xx', Err( 'Invalid JSON: expected value at line 1 column 1 ' - "[type=json_invalid, input_value='xx', input_type=str" + "[type=json_invalid, input_value='xx', input_type=str]" ), ), ( b'xx', Err( 'Invalid JSON: expected value at line 1 column 1 ' - "[type=json_invalid, input_value=b'xx', input_type=bytes" + "[type=json_invalid, input_value=b'xx', input_type=bytes]" ), ), ( bytearray(b'xx'), Err( 'Invalid JSON: expected value at line 1 column 1 ' - "[type=json_invalid, input_value=bytearray(b'xx'), input_type=bytearray" + "[type=json_invalid, input_value=bytearray(b'xx'), input_type=bytearray]" ), ), ], From bed050bcd41ed23e302835f5e1ba415db44fa9c1 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco 
<1755071+adriangb@users.noreply.github.com> Date: Wed, 20 Sep 2023 03:15:31 -0400 Subject: [PATCH 041/550] Fix serialization of model subclasses via unions when definition referneces are used (#977) --- src/serializers/shared.rs | 4 +- src/serializers/type_serializers/dataclass.rs | 4 +- .../type_serializers/definitions.rs | 6 ++ src/serializers/type_serializers/model.rs | 4 +- src/serializers/type_serializers/nullable.rs | 6 +- src/serializers/type_serializers/union.rs | 14 +++-- .../type_serializers/with_default.rs | 6 +- tests/serializers/test_union.py | 55 ++++++++++++++++++- 8 files changed, 80 insertions(+), 19 deletions(-) diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 7c24ff6db..b9b0c1fe1 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -13,7 +13,7 @@ use serde_json::ser::PrettyFormatter; use crate::build_tools::py_schema_err; use crate::build_tools::py_schema_error_type; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::py_gc::PyGcTraverse; use crate::tools::{py_err, SchemaDict}; @@ -293,7 +293,7 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug { fn get_name(&self) -> &str; /// Used by union serializers to decide if it's worth trying again while allowing subclasses - fn retry_with_lax_check(&self) -> bool { + fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { false } diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index 787e267dd..124f962ad 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -6,7 +6,7 @@ use std::borrow::Cow; use ahash::AHashMap; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::tools::SchemaDict; use super::{ @@ -179,7 +179,7 @@ impl 
TypeSerializer for DataclassSerializer { &self.name } - fn retry_with_lax_check(&self) -> bool { + fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { true } } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 19d6e75ed..4614bbc56 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -4,6 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; +use crate::definitions::Definitions; use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; @@ -96,4 +97,9 @@ impl TypeSerializer for DefinitionRefSerializer { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { + let comb_serializer = definitions.get(self.serializer_id).unwrap(); + comb_serializer.retry_with_lax_check(definitions) + } } diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 8a2eeb4e1..c5b252fbf 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -13,7 +13,7 @@ use super::{ }; use crate::build_tools::py_schema_err; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::serializers::errors::PydanticSerializationUnexpectedValue; use crate::tools::SchemaDict; @@ -228,7 +228,7 @@ impl TypeSerializer for ModelSerializer { &self.name } - fn retry_with_lax_check(&self) -> bool { + fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { true } } diff --git a/src/serializers/type_serializers/nullable.rs b/src/serializers/type_serializers/nullable.rs index 23349ec81..837d6c5f1 100644 --- a/src/serializers/type_serializers/nullable.rs +++ b/src/serializers/type_serializers/nullable.rs @@ -4,7 +4,7 @@ use pyo3::intern; use 
pyo3::prelude::*; use pyo3::types::PyDict; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::tools::SchemaDict; use super::{infer_json_key_known, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, TypeSerializer}; @@ -75,7 +75,7 @@ impl TypeSerializer for NullableSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self) -> bool { - self.serializer.retry_with_lax_check() + fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { + self.serializer.retry_with_lax_check(definitions) } } diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index d9b185017..70818959e 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -4,7 +4,7 @@ use pyo3::types::{PyDict, PyList, PyTuple}; use std::borrow::Cow; use crate::build_tools::py_schema_err; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::tools::SchemaDict; use crate::PydanticSerializationUnexpectedValue; @@ -87,7 +87,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check() { + if self.retry_with_lax_check(extra.definitions) { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -116,7 +116,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check() { + if self.retry_with_lax_check(extra.definitions) { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.json_key(key, &new_extra) { @@ -153,7 +153,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check() { + if self.retry_with_lax_check(extra.definitions) { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ 
-174,8 +174,10 @@ impl TypeSerializer for UnionSerializer { &self.name } - fn retry_with_lax_check(&self) -> bool { - self.choices.iter().any(TypeSerializer::retry_with_lax_check) + fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { + self.choices + .iter() + .any(|choice| choice.retry_with_lax_check(definitions)) } } diff --git a/src/serializers/type_serializers/with_default.rs b/src/serializers/type_serializers/with_default.rs index d20c273a1..148c05052 100644 --- a/src/serializers/type_serializers/with_default.rs +++ b/src/serializers/type_serializers/with_default.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::tools::SchemaDict; use crate::validators::DefaultType; @@ -67,8 +67,8 @@ impl TypeSerializer for WithDefaultSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self) -> bool { - self.serializer.retry_with_lax_check() + fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { + self.serializer.retry_with_lax_check(definitions) } fn get_default(&self, py: Python) -> PyResult> { diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index c59996669..f81e33a6b 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -1,7 +1,7 @@ import dataclasses import json import re -from typing import Union +from typing import Any, ClassVar, Union import pytest from typing_extensions import Literal @@ -401,3 +401,56 @@ class Model(BaseModel): assert s.to_python(m) == {'value': data, 'value_types_reversed': data} assert s.to_json(m) == f'{{"value":{json_value},"value_types_reversed":{json_value}}}'.encode() + + +def test_union_serializes_model_subclass_from_definition() -> None: + class BaseModel: + __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' + + def __init__(self, **kwargs: Any): + 
for key, value in kwargs.items(): + setattr(self, key, value) + + class User(BaseModel): + name: str + + class DBUser(User): + password: str + __pydantic_serializer__: ClassVar[SchemaSerializer] + + DBUser.__pydantic_serializer__ = SchemaSerializer( + core_schema.model_schema( + DBUser, + core_schema.model_fields_schema( + { + 'name': core_schema.model_field(core_schema.str_schema()), + 'password': core_schema.model_field(core_schema.str_schema()), + } + ), + ) + ) + + class Item(BaseModel): + price: float + + s = SchemaSerializer( + core_schema.definitions_schema( + core_schema.union_schema( + [core_schema.definition_reference_schema('User'), core_schema.definition_reference_schema('Item')] + ), + [ + core_schema.model_schema( + User, + core_schema.model_fields_schema({'name': core_schema.model_field(core_schema.str_schema())}), + ref='User', + ), + core_schema.model_schema( + Item, + core_schema.model_fields_schema({'price': core_schema.model_field(core_schema.float_schema())}), + ref='Item', + ), + ], + ) + ) + + assert s.to_python(DBUser(name='John', password='secret')) == {'name': 'John'} From 157a64389de6bf58f0ae990ec1f7fe55237b8009 Mon Sep 17 00:00:00 2001 From: Serge Matveenko Date: Wed, 20 Sep 2023 14:26:19 +0200 Subject: [PATCH 042/550] =?UTF-8?q?=E2=9C=A8=20Implement=20optional=20`num?= =?UTF-8?q?ber`=20to=20`str`=20coercion=20(#975)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com> --- python/pydantic_core/core_schema.py | 2 ++ src/input/input_abstract.rs | 6 ++-- src/input/input_json.rs | 6 +++- src/input/input_python.rs | 11 ++++++- src/input/return_enums.rs | 6 ++++ src/validators/string.rs | 13 ++++++-- src/validators/url.rs | 4 +-- tests/validators/test_function.py | 10 +++++-- tests/validators/test_string.py | 46 ++++++++++++++++++++++++++++- tests/validators/test_union.py | 5 +++- 10 files changed, 96 insertions(+), 13 
deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index db0b1bd54..6c44f034e 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -63,6 +63,7 @@ class CoreConfig(TypedDict, total=False): hide_input_in_errors: Whether to hide input data from `ValidationError` representation. validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError. Requires exceptiongroup backport pre Python 3.11. + coerce_numbers_to_str: Whether to enable coercion of any `Number` type to `str` (not applicable in `strict` mode). """ title: str @@ -95,6 +96,7 @@ class CoreConfig(TypedDict, total=False): # used to hide input data from ValidationError repr hide_input_in_errors: bool validation_error_cause: bool # default: False + coerce_numbers_to_str: bool # default: False IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index f4a760a45..655ba24b9 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -91,16 +91,16 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn parse_json(&'a self) -> ValResult<'a, JsonInput>; - fn validate_str(&'a self, strict: bool) -> ValResult> { + fn validate_str(&'a self, strict: bool, coerce_numbers_to_str: bool) -> ValResult> { if strict { self.strict_str() } else { - self.lax_str() + self.lax_str(coerce_numbers_to_str) } } fn strict_str(&'a self) -> ValResult>; #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_str(&'a self) -> ValResult> { + fn lax_str(&'a self, _coerce_numbers_to_str: bool) -> ValResult> { self.strict_str() } diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 07f3554e6..e375f5755 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -88,9 +88,13 @@ impl<'a> Input<'a> for JsonInput { _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } - fn lax_str(&'a 
self) -> ValResult> { + fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { match self { JsonInput::String(s) => Ok(s.as_str().into()), + JsonInput::BigInt(v) if coerce_numbers_to_str => Ok(v.to_string().into()), + JsonInput::Float(v) if coerce_numbers_to_str => Ok(v.to_string().into()), + JsonInput::Int(v) if coerce_numbers_to_str => Ok(v.to_string().into()), + JsonInput::Uint(v) if coerce_numbers_to_str => Ok(v.to_string().into()), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 9e5fd1d1f..cf84c5517 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -223,7 +223,7 @@ impl<'a> Input<'a> for PyAny { } } - fn lax_str(&'a self) -> ValResult> { + fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { if let Ok(py_str) = ::try_from_exact(self) { Ok(py_str.into()) } else if let Ok(py_str) = self.downcast::() { @@ -246,6 +246,15 @@ impl<'a> Input<'a> for PyAny { Err(_) => return Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), }; Ok(s.into()) + } else if coerce_numbers_to_str && { + let py = self.py(); + let decimal_type: Py = get_decimal_type(py); + + self.is_instance_of::() + || self.is_instance_of::() + || self.is_instance(decimal_type.as_ref(py)).unwrap_or_default() + } { + Ok(self.str()?.into()) } else { Err(ValError::new(ErrorTypeDefaults::StringType, self)) } diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 56f86f580..d7d5a7f3b 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -768,6 +768,12 @@ impl<'a> From<&'a str> for EitherString<'a> { } } +impl<'a> From for EitherString<'a> { + fn from(data: String) -> Self { + Self::Cow(Cow::Owned(data)) + } +} + impl<'a> From<&'a PyString> for EitherString<'a> { fn from(date: &'a PyString) -> Self { Self::Py(date) diff --git a/src/validators/string.rs b/src/validators/string.rs index 7b1a75a6d..0b4ec4a7b 100644 --- 
a/src/validators/string.rs +++ b/src/validators/string.rs @@ -13,6 +13,7 @@ use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationSta #[derive(Debug, Clone)] pub struct StrValidator { strict: bool, + coerce_numbers_to_str: bool, } impl BuildValidator for StrValidator { @@ -30,6 +31,7 @@ impl BuildValidator for StrValidator { } else { Ok(Self { strict: con_str_validator.strict, + coerce_numbers_to_str: con_str_validator.coerce_numbers_to_str, } .into()) } @@ -45,7 +47,7 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(state.strict_or(self.strict))?; + let either_str = input.validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)?; Ok(either_str.into_py(py)) } @@ -76,6 +78,7 @@ pub struct StrConstrainedValidator { strip_whitespace: bool, to_lower: bool, to_upper: bool, + coerce_numbers_to_str: bool, } impl_py_gc_traverse!(StrConstrainedValidator {}); @@ -87,7 +90,7 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(state.strict_or(self.strict))?; + let either_str = input.validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)?; let cow = either_str.as_cow()?; let mut str = cow.as_ref(); if self.strip_whitespace { @@ -188,6 +191,11 @@ impl StrConstrainedValidator { let to_upper: bool = schema_or_config(schema, config, intern!(py, "to_upper"), intern!(py, "str_to_upper"))?.unwrap_or(false); + let coerce_numbers_to_str = config + .and_then(|c| c.get_item("coerce_numbers_to_str")) + .and_then(|v| v.is_true().ok()) + .unwrap_or(false); + Ok(Self { strict: is_strict(schema, config)?, pattern, @@ -196,6 +204,7 @@ impl StrConstrainedValidator { strip_whitespace, to_lower, to_upper, + coerce_numbers_to_str, }) } diff --git a/src/validators/url.rs b/src/validators/url.rs index 
a261d5e03..0afc76e59 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -111,7 +111,7 @@ impl Validator for UrlValidator { impl UrlValidator { fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, Url> { - match input.validate_str(strict) { + match input.validate_str(strict, false) { Ok(either_str) => { let cow = either_str.as_cow()?; let url_str = cow.as_ref(); @@ -251,7 +251,7 @@ impl Validator for MultiHostUrlValidator { impl MultiHostUrlValidator { fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, PyMultiHostUrl> { - match input.validate_str(strict) { + match input.validate_str(strict, false) { Ok(either_str) => { let cow = either_str.as_cow()?; let url_str = cow.as_ref(); diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 07710be41..1a63c27a3 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -201,7 +201,10 @@ def f(input_value, validator, info): {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} ) - assert v.validate_python('input value') == 'ValidatorCallable(Str(StrValidator{strict:false}))' + assert ( + v.validate_python('input value') + == 'ValidatorCallable(Str(StrValidator{strict:false,coerce_numbers_to_str:false}))' + ) def test_function_wrap_str(): @@ -212,7 +215,10 @@ def f(input_value, validator, info): {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} ) - assert v.validate_python('input value') == 'ValidatorCallable(Str(StrValidator{strict:false}))' + assert ( + v.validate_python('input value') + == 'ValidatorCallable(Str(StrValidator{strict:false,coerce_numbers_to_str:false}))' + ) def test_function_wrap_not_callable(): diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index 8eae0ab65..0cff9eb95 100644 --- 
a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -1,5 +1,6 @@ import re from decimal import Decimal +from numbers import Number from typing import Any, Dict, Union import pytest @@ -201,7 +202,10 @@ def test_regex_error(): def test_default_validator(): v = SchemaValidator(core_schema.str_schema(strict=True, to_lower=False), {'str_strip_whitespace': False}) - assert plain_repr(v) == 'SchemaValidator(title="str",validator=Str(StrValidator{strict:true}),definitions=[])' + assert ( + plain_repr(v) + == 'SchemaValidator(title="str",validator=Str(StrValidator{strict:true,coerce_numbers_to_str:false}),definitions=[])' + ) @pytest.fixture(scope='session', name='FruitEnum') @@ -253,3 +257,43 @@ class StrSubclass(str): assert not isinstance(v.validate_python(StrSubclass('')), StrSubclass) assert not isinstance(v.validate_python(StrSubclass(''), strict=True), StrSubclass) + + +def test_coerce_numbers_to_str_disabled_in_strict_mode() -> None: + config = core_schema.CoreConfig(coerce_numbers_to_str=True) + + v = SchemaValidator(core_schema.str_schema(strict=True), config) + with pytest.raises(ValidationError): + v.validate_python(42) + with pytest.raises(ValidationError): + v.validate_json('42') + + +@pytest.mark.parametrize( + ('number', 'expected_str'), + [ + pytest.param(42, '42', id='42'), + pytest.param(42.0, '42.0', id='42.0'), + pytest.param(Decimal('42.0'), '42.0', id="Decimal('42.0')"), + ], +) +def test_coerce_numbers_to_str(number: Number, expected_str: str) -> None: + config = core_schema.CoreConfig(coerce_numbers_to_str=True) + + v = SchemaValidator(core_schema.str_schema(), config) + assert v.validate_python(number) == expected_str + + +@pytest.mark.parametrize( + ('number', 'expected_str'), + [ + pytest.param('42', '42', id='42'), + pytest.param('42.0', '42', id='42.0'), + pytest.param('42.13', '42.13', id='42.13'), + ], +) +def test_coerce_numbers_to_str_from_json(number: str, expected_str: str) -> None: + config = 
core_schema.CoreConfig(coerce_numbers_to_str=True) + + v = SchemaValidator(core_schema.str_schema(), config) + assert v.validate_json(number) == expected_str diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 7e0f608f1..05072b806 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -256,7 +256,10 @@ def test_empty_choices(): def test_one_choice(): v = SchemaValidator({'type': 'union', 'choices': [{'type': 'str'}]}) - assert plain_repr(v) == 'SchemaValidator(title="str",validator=Str(StrValidator{strict:false}),definitions=[])' + assert ( + plain_repr(v) + == 'SchemaValidator(title="str",validator=Str(StrValidator{strict:false,coerce_numbers_to_str:false}),definitions=[])' + ) assert v.validate_python('hello') == 'hello' From 621fb035f24e243a45f3e17aaeef2d9afc31a7de Mon Sep 17 00:00:00 2001 From: sydney-runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 20 Sep 2023 09:42:26 -0500 Subject: [PATCH 043/550] Add support for hiding input in errors and json (#973) Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com> --- python/pydantic_core/_pydantic_core.pyi | 15 +++++++-- src/build_tools.rs | 2 +- src/errors/validation_exception.rs | 45 ++++++++++++++++--------- tests/test_errors.py | 18 ++++++++++ 4 files changed, 61 insertions(+), 19 deletions(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index be2b64793..ca0104b9d 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -731,18 +731,28 @@ class ValidationError(ValueError): Returns: The number of errors in the validation error. """ - def errors(self, *, include_url: bool = True, include_context: bool = True) -> list[ErrorDetails]: + def errors( + self, *, include_url: bool = True, include_context: bool = True, include_input: bool = True + ) -> list[ErrorDetails]: """ Details about each error in the validation error. 
Args: include_url: Whether to include a URL to documentation on the error each error. include_context: Whether to include the context of each error. + include_input: Whether to include the input value of each error. Returns: A list of [`ErrorDetails`][pydantic_core.ErrorDetails] for each error in the validation error. """ - def json(self, *, indent: int | None = None, include_url: bool = True, include_context: bool = True) -> str: + def json( + self, + *, + indent: int | None = None, + include_url: bool = True, + include_context: bool = True, + include_input: bool = True, + ) -> str: """ Same as [`errors()`][pydantic_core.ValidationError.errors] but returns a JSON string. @@ -750,6 +760,7 @@ class ValidationError(ValueError): indent: The number of spaces to indent the JSON by, or `None` for no indentation - compact JSON. include_url: Whether to include a URL to documentation on the error each error. include_context: Whether to include the context of each error. + include_input: Whether to include the input value of each error. Returns: a JSON string. 
diff --git a/src/build_tools.rs b/src/build_tools.rs index 47fa569ac..c242f97a3 100644 --- a/src/build_tools.rs +++ b/src/build_tools.rs @@ -125,7 +125,7 @@ impl SchemaError { fn errors(&self, py: Python) -> PyResult> { match &self.0 { SchemaErrorEnum::Message(_) => Ok(PyList::empty(py).into_py(py)), - SchemaErrorEnum::ValidationError(error) => error.errors(py, false, false), + SchemaErrorEnum::ValidationError(error) => error.errors(py, false, false, true), } } diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index d0977001a..09154d8ac 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -265,8 +265,14 @@ impl ValidationError { self.line_errors.len() } - #[pyo3(signature = (*, include_url = true, include_context = true))] - pub fn errors(&self, py: Python, include_url: bool, include_context: bool) -> PyResult> { + #[pyo3(signature = (*, include_url = true, include_context = true, include_input = true))] + pub fn errors( + &self, + py: Python, + include_url: bool, + include_context: bool, + include_input: bool, + ) -> PyResult> { let url_prefix = get_url_prefix(py, include_url); let mut iteration_error = None; let list = PyList::new( @@ -278,7 +284,7 @@ impl ValidationError { if iteration_error.is_some() { return py.None(); } - e.as_dict(py, url_prefix, include_context, self.input_type) + e.as_dict(py, url_prefix, include_context, self.input_type, include_input) .unwrap_or_else(|err| { iteration_error = Some(err); py.None() @@ -292,13 +298,14 @@ impl ValidationError { } } - #[pyo3(signature = (*, indent = None, include_url = true, include_context = true))] + #[pyo3(signature = (*, indent = None, include_url = true, include_context = true, include_input = true))] pub fn json<'py>( &self, py: Python<'py>, indent: Option, include_url: bool, include_context: bool, + include_input: bool, ) -> PyResult<&'py PyString> { let state = SerializationState::new("iso8601", "utf8")?; let extra = 
state.extra(py, &SerMode::Json, true, false, false, true, None); @@ -307,6 +314,7 @@ impl ValidationError { line_errors: &self.line_errors, url_prefix: get_url_prefix(py, include_url), include_context, + include_input, extra: &extra, input_type: &self.input_type, }; @@ -477,12 +485,15 @@ impl PyLineError { url_prefix: Option<&str>, include_context: bool, input_type: InputType, + include_input: bool, ) -> PyResult { let dict = PyDict::new(py); dict.set_item("type", self.error_type.type_string())?; dict.set_item("loc", self.location.to_object(py))?; dict.set_item("msg", self.error_type.render_message(py, input_type)?)?; - dict.set_item("input", &self.input_value)?; + if include_input { + dict.set_item("input", &self.input_value)?; + } if include_context { if let Some(context) = self.error_type.py_dict(py)? { dict.set_item("ctx", context)?; @@ -563,6 +574,7 @@ struct ValidationErrorSerializer<'py> { line_errors: &'py [PyLineError], url_prefix: Option<&'py str>, include_context: bool, + include_input: bool, extra: &'py crate::serializers::Extra<'py>, input_type: &'py InputType, } @@ -579,6 +591,7 @@ impl<'py> Serialize for ValidationErrorSerializer<'py> { line_error, url_prefix: self.url_prefix, include_context: self.include_context, + include_input: self.include_input, extra: self.extra, input_type: self.input_type, }; @@ -593,6 +606,7 @@ struct PyLineErrorSerializer<'py> { line_error: &'py PyLineError, url_prefix: Option<&'py str>, include_context: bool, + include_input: bool, extra: &'py crate::serializers::Extra<'py>, input_type: &'py InputType, } @@ -603,13 +617,10 @@ impl<'py> Serialize for PyLineErrorSerializer<'py> { S: Serializer, { let py = self.py; - let mut size = 4; - if self.url_prefix.is_some() { - size += 1; - } - if self.include_context { - size += 1; - } + let size = 3 + [self.url_prefix.is_some(), self.include_context, self.include_input] + .into_iter() + .filter(|b| *b) + .count(); let mut map = serializer.serialize_map(Some(size))?; 
map.serialize_entry("type", &self.line_error.error_type.type_string())?; @@ -623,10 +634,12 @@ impl<'py> Serialize for PyLineErrorSerializer<'py> { .map_err(py_err_json::)?; map.serialize_entry("msg", &msg)?; - map.serialize_entry( - "input", - &self.extra.serialize_infer(self.line_error.input_value.as_ref(py)), - )?; + if self.include_input { + map.serialize_entry( + "input", + &self.extra.serialize_infer(self.line_error.input_value.as_ref(py)), + )?; + } if self.include_context { if let Some(context) = self.line_error.error_type.py_dict(py).map_err(py_err_json::)? { diff --git a/tests/test_errors.py b/tests/test_errors.py index fe71b9860..15d2c78af 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1070,3 +1070,21 @@ def test_loc_with_dots(pydantic_version): "[type=int_parsing, input_value='x', input_type=str]\n" f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing' ) + + +def test_hide_input_in_error() -> None: + s = SchemaValidator({'type': 'int'}) + with pytest.raises(ValidationError) as exc_info: + s.validate_python('definitely not an int') + + for error in exc_info.value.errors(include_input=False): + assert 'input' not in error + + +def test_hide_input_in_json() -> None: + s = SchemaValidator({'type': 'int'}) + with pytest.raises(ValidationError) as exc_info: + s.validate_python('definitely not an int') + + for error in exc_info.value.errors(include_input=False): + assert 'input' not in error From 0f9a5c9569b04a787275410bf5d8764f194c3bca Mon Sep 17 00:00:00 2001 From: Serge Matveenko Date: Wed, 20 Sep 2023 16:48:01 +0200 Subject: [PATCH 044/550] =?UTF-8?q?=F0=9F=94=96=20Bump=20version=20to=202.?= =?UTF-8?q?9.0=20(#979)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 55224a954..1027b8597 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ 
-242,7 +242,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.8.0" +version = "2.9.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index a4dfda154..9e193b3d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.8.0" +version = "2.9.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From b8d3b95ac27759db8df8def65b100b03db19eb31 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Wed, 20 Sep 2023 16:03:02 +0100 Subject: [PATCH 045/550] use `TypedDict` from `typing_extensions` on <3.12 (#978) --- python/pydantic_core/core_schema.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 6c44f034e..3abf393b8 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -11,15 +11,20 @@ from decimal import Decimal from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict + if sys.version_info < (3, 11): from typing_extensions import Protocol, Required, TypeAlias else: from typing import Protocol, Required, TypeAlias if sys.version_info < (3, 9): - from typing_extensions import Literal, TypedDict + from typing_extensions import Literal else: - from typing import Literal, TypedDict + from typing import Literal if TYPE_CHECKING: from pydantic_core import PydanticUndefined From 4c84ed877c6806a9cca6649756160619a79c7e70 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 21 Sep 2023 14:22:20 +0100 Subject: [PATCH 046/550] add switch to change regex engine from Rust to Python (#983) --- python/pydantic_core/core_schema.py | 10 ++++ src/validators/string.rs | 62 
++++++++++++++++++++++--- tests/validators/test_string.py | 71 +++++++++++++++++++++++------ 3 files changed, 123 insertions(+), 20 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 3abf393b8..718bdf969 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -69,6 +69,7 @@ class CoreConfig(TypedDict, total=False): validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError. Requires exceptiongroup backport pre Python 3.11. coerce_numbers_to_str: Whether to enable coercion of any `Number` type to `str` (not applicable in `strict` mode). + regex_engine: The regex engine to use for regex pattern validation. Default is 'rust-regex'. See `StringSchema`. """ title: str @@ -752,6 +753,7 @@ class StringSchema(TypedDict, total=False): strip_whitespace: bool to_lower: bool to_upper: bool + regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex' strict: bool ref: str metadata: Any @@ -766,6 +768,7 @@ def str_schema( strip_whitespace: bool | None = None, to_lower: bool | None = None, to_upper: bool | None = None, + regex_engine: Literal['rust-regex', 'python-re'] | None = None, strict: bool | None = None, ref: str | None = None, metadata: Any = None, @@ -789,6 +792,12 @@ def str_schema( strip_whitespace: Whether to strip whitespace from the value to_lower: Whether to convert the value to lowercase to_upper: Whether to convert the value to uppercase + regex_engine: The regex engine to use for pattern validation. Default is 'rust-regex'. + - `rust-regex` uses the [`regex`](https://docs.rs/regex) Rust + crate, which is non-backtracking and therefore more DDoS + resistant, but does not support all regex features. + - `python-re` use the [`re`](https://docs.python.org/3/library/re.html) module, + which supports all regex features, but may be slower. 
strict: Whether the value should be a string or a value that can be converted to a string ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -802,6 +811,7 @@ def str_schema( strip_whitespace=strip_whitespace, to_lower=to_lower, to_upper=to_upper, + regex_engine=regex_engine, strict=strict, ref=ref, metadata=metadata, diff --git a/src/validators/string.rs b/src/validators/string.rs index 0b4ec4a7b..d997dcf85 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -72,7 +72,7 @@ impl Validator for StrValidator { #[derive(Debug, Clone, Default)] pub struct StrConstrainedValidator { strict: bool, - pattern: Option, + pattern: Option, max_length: Option, min_length: Option, strip_whitespace: bool, @@ -126,10 +126,10 @@ impl Validator for StrConstrainedValidator { } if let Some(pattern) = &self.pattern { - if !pattern.is_match(str) { + if !pattern.is_match(py, str)? { return Err(ValError::new( ErrorType::StringPatternMismatch { - pattern: pattern.to_string(), + pattern: pattern.pattern.clone(), context: None, }, input, @@ -170,10 +170,16 @@ impl Validator for StrConstrainedValidator { impl StrConstrainedValidator { fn build(schema: &PyDict, config: Option<&PyDict>) -> PyResult { let py = schema.py(); - let pattern = match schema.get_as(intern!(py, "pattern"))? { - Some(s) => Some(Regex::new(s).map_err(|e| py_schema_error_type!("{}", e))?), - None => None, - }; + + let pattern = schema + .get_as(intern!(py, "pattern"))? + .map(|s| { + let regex_engine = + schema_or_config(schema, config, intern!(py, "regex_engine"), intern!(py, "regex_engine"))? 
+ .unwrap_or(RegexEngine::RUST_REGEX); + Pattern::compile(py, s, regex_engine) + }) + .transpose()?; let min_length: Option = schema_or_config(schema, config, intern!(py, "min_length"), intern!(py, "str_min_length"))?; let max_length: Option = @@ -219,3 +225,45 @@ impl StrConstrainedValidator { || self.to_upper } } + +#[derive(Debug, Clone)] +struct Pattern { + pattern: String, + engine: RegexEngine, +} + +#[derive(Debug, Clone)] +enum RegexEngine { + RustRegex(Regex), + PythonRe(PyObject), +} + +impl RegexEngine { + const RUST_REGEX: &str = "rust-regex"; + const PYTHON_RE: &str = "python-re"; +} + +impl Pattern { + fn compile(py: Python<'_>, pattern: String, engine: &str) -> PyResult { + let engine = match engine { + RegexEngine::RUST_REGEX => { + RegexEngine::RustRegex(Regex::new(&pattern).map_err(|e| py_schema_error_type!("{}", e))?) + } + RegexEngine::PYTHON_RE => { + let re_compile = py.import(intern!(py, "re"))?.getattr(intern!(py, "compile"))?; + RegexEngine::PythonRe(re_compile.call1((&pattern,))?.into()) + } + _ => return Err(py_schema_error_type!("Invalid regex engine: {}", engine)), + }; + Ok(Self { pattern, engine }) + } + + fn is_match(&self, py: Python<'_>, target: &str) -> PyResult { + match &self.engine { + RegexEngine::RustRegex(regex) => Ok(regex.is_match(target)), + RegexEngine::PythonRe(py_regex) => { + Ok(!py_regex.call_method1(py, intern!(py, "match"), (target,))?.is_none(py)) + } + } + } +} diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index 0cff9eb95..acb145a58 100644 --- a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -167,7 +167,8 @@ def test_str_constrained_config(): v.validate_python('test long') -def test_invalid_regex(): +@pytest.mark.parametrize('engine', [None, 'rust-regex', 'python-re']) +def test_invalid_regex(engine): # TODO uncomment and fix once #150 is done # with pytest.raises(SchemaError) as exc_info: # SchemaValidator({'type': 'str', 'pattern': 123}) @@ -175,18 
+176,25 @@ def test_invalid_regex(): # 'Error building "str" validator:\n TypeError: \'int\' object cannot be converted to \'PyString\'' # ) with pytest.raises(SchemaError) as exc_info: - SchemaValidator({'type': 'str', 'pattern': '(abc'}) - assert exc_info.value.args[0] == ( - 'Error building "str" validator:\n' - ' SchemaError: regex parse error:\n' - ' (abc\n' - ' ^\n' - 'error: unclosed group' - ) - - -def test_regex_error(): - v = SchemaValidator({'type': 'str', 'pattern': '11'}) + SchemaValidator(core_schema.str_schema(pattern='(abc', regex_engine=engine)) + + if engine is None or engine == 'rust-regex': + assert exc_info.value.args[0] == ( + 'Error building "str" validator:\n' + ' SchemaError: regex parse error:\n' + ' (abc\n' + ' ^\n' + 'error: unclosed group' + ) + elif engine == 'python-re': + assert exc_info.value.args[0] == ( + 'Error building "str" validator:\n error: missing ), unterminated subpattern at position 0' + ) + + +@pytest.mark.parametrize('engine', [None, 'rust-regex', 'python-re']) +def test_regex_error(engine): + v = SchemaValidator(core_schema.str_schema(pattern='11', regex_engine=engine)) with pytest.raises(ValidationError) as exc_info: v.validate_python('12') assert exc_info.value.errors(include_url=False) == [ @@ -297,3 +305,40 @@ def test_coerce_numbers_to_str_from_json(number: str, expected_str: str) -> None v = SchemaValidator(core_schema.str_schema(), config) assert v.validate_json(number) == expected_str + + +@pytest.mark.parametrize('mode', (None, 'schema', 'config')) +def test_backtracking_regex_rust_unsupported(mode) -> None: + pattern = r'r(#*)".*?"\1' + + with pytest.raises(SchemaError) as exc_info: + if mode is None: + # rust-regex is the default + SchemaValidator(core_schema.str_schema(pattern=pattern)) + elif mode == 'schema': + SchemaValidator(core_schema.str_schema(pattern=pattern, regex_engine='rust-regex')) + elif mode == 'config': + SchemaValidator(core_schema.str_schema(pattern=pattern), 
core_schema.CoreConfig(regex_engine='rust-regex')) + + assert exc_info.value.args[0] == ( + 'Error building \"str\" validator:\n' + ' SchemaError: regex parse error:\n' + ' r(#*)\".*?\"\\1\n' + ' ^^\n' + 'error: backreferences are not supported' + ) + + +@pytest.mark.parametrize('mode', ('schema', 'config')) +def test_backtracking_regex_python(mode) -> None: + pattern = r'r(#*)".*?"\1' + + if mode == 'schema': + v = SchemaValidator(core_schema.str_schema(pattern=pattern, regex_engine='python-re')) + elif mode == 'config': + v = SchemaValidator(core_schema.str_schema(pattern=pattern), core_schema.CoreConfig(regex_engine='python-re')) + assert v.validate_python('r""') == 'r""' + assert v.validate_python('r#""#') == 'r#""#' + with pytest.raises(ValidationError): + # not a valid match for the pattern + v.validate_python('r#"#') From 33a7cc0727bef0d5a0a293e0662dfd81af35658f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 21 Sep 2023 15:16:19 +0100 Subject: [PATCH 047/550] =?UTF-8?q?=F0=9F=90=9B=20Fix=20handling=20of=20`U?= =?UTF-8?q?UID`=20values=20having=20`UUID.version=3DNone`=20(#981)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Serge Matveenko Co-authored-by: David Hewitt <1939362+davidhewitt@users.noreply.github.com> --- src/validators/uuid.rs | 14 ++++++++------ tests/validators/test_uuid.py | 18 +++++++++++++++++- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index 6467cc406..ca924ce66 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -59,7 +59,7 @@ impl From for Version { #[derive(Debug, Clone)] pub struct UuidValidator { strict: bool, - version: Option, + version: Option, } impl BuildValidator for UuidValidator { @@ -71,10 +71,11 @@ impl BuildValidator for UuidValidator { _definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); + // Note(lig): let's keep this conversion through the Version 
enum just for the sake of validation let version = schema.get_as::(intern!(py, "version"))?.map(Version::from); Ok(Self { strict: is_strict(schema, config)?, - version, + version: version.map(usize::from), } .into()) } @@ -92,9 +93,11 @@ impl Validator for UuidValidator { let class = get_uuid_type(py)?; if let Some(py_input) = input.input_is_instance(class) { if let Some(expected_version) = self.version { - let py_input_version: usize = py_input.getattr(intern!(py, "version"))?.extract()?; - let expected_version = usize::from(expected_version); - if expected_version != py_input_version { + let py_input_version: Option = py_input.getattr(intern!(py, "version"))?.extract()?; + if !match py_input_version { + Some(py_input_version) => py_input_version == expected_version, + None => false, + } { return Err(ValError::new( ErrorType::UuidVersion { expected_version, @@ -179,7 +182,6 @@ impl UuidValidator { if let Some(expected_version) = self.version { let v1 = uuid.get_version_num(); - let expected_version = usize::from(expected_version); if v1 != expected_version { return Err(ValError::new( ErrorType::UuidVersion { diff --git a/tests/validators/test_uuid.py b/tests/validators/test_uuid.py index 4ecd7feff..9afe2d2b8 100644 --- a/tests/validators/test_uuid.py +++ b/tests/validators/test_uuid.py @@ -23,6 +23,7 @@ ('6ba7b810-9dad-11d1-80b4-00c04fd430c8', UUID('6ba7b810-9dad-11d1-80b4-00c04fd430c8')), ('886313e1-3b8a-5372-9b90-0c9aee199e5d', UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d')), ('c0a8f9a8-aa5e-482b-a067-9cb3a51f5c11', UUID('c0a8f9a8-aa5e-482b-a067-9cb3a51f5c11')), + ('00000000-8000-4000-8000-000000000000', UUID('00000000-8000-4000-8000-000000000000')), (b'\x12\x34\x56\x78' * 4, UUID('12345678-1234-5678-1234-567812345678')), (b'\x00\x00\x00\x00' * 4, UUID('00000000-0000-0000-0000-000000000000')), (b'ebcdab58-6eb8-46fb-a190-d07a33e9eac8', UUID('ebcdab58-6eb8-46fb-a190-d07a33e9eac8')), @@ -118,6 +119,16 @@ def test_uuid_strict(input_value, expected): 
(UUID('0e7ac198-9acd-4c0c-b4b4-761974bf71d7'), 4, UUID('0e7ac198-9acd-4c0c-b4b4-761974bf71d7')), ('0e7ac198-9acd-4c0c-b4b4-761974bf71d7', 4, UUID('0e7ac198-9acd-4c0c-b4b4-761974bf71d7')), (UUID('0e7ac198-9acd-4c0c-b4b4-761974bf71d7'), 4, UUID('0e7ac198-9acd-4c0c-b4b4-761974bf71d7')), + # Cases from pydantic#7355 and pydantic#7537 + # `UUID.version` makes sense for RFC 4122 UUIDs only. For non RFC 4122 UUIDs Python uses `UUID.version=None` + ('00000000-8000-4000-8000-000000000000', 4, UUID('00000000-8000-4000-8000-000000000000')), + (UUID('00000000-8000-4000-8000-000000000000'), 4, UUID('00000000-8000-4000-8000-000000000000')), + ('00000000-7fff-4000-7fff-000000000000', None, UUID('00000000-7fff-4000-7fff-000000000000')), + (UUID('00000000-7fff-4000-7fff-000000000000'), None, UUID('00000000-7fff-4000-7fff-000000000000')), + (UUID('00000000-7fff-4000-7fff-000000000000'), 4, Err('UUID version 4 expected')), + ('b34b6755-f49c-3bd2-6f06-131a708c2bf3', None, UUID('b34b6755-f49c-3bd2-6f06-131a708c2bf3')), + (UUID('b34b6755-f49c-3bd2-6f06-131a708c2bf3'), None, UUID('b34b6755-f49c-3bd2-6f06-131a708c2bf3')), + (UUID('b34b6755-f49c-3bd2-6f06-131a708c2bf3'), 4, Err('UUID version 4 expected')), # Invalid UUIDs ('a6cc5730-2261-11ee-9c43-2eb5a363657c', 5, Err('UUID version 5 expected')), (UUID('a6cc5730-2261-11ee-9c43-2eb5a363657c'), 5, Err('UUID version 5 expected')), @@ -130,7 +141,12 @@ def test_uuid_strict(input_value, expected): ], ) def test_uuid_version(input_value, version, expected): - v = SchemaValidator({'type': 'uuid', 'version': version}) + schema = {'type': 'uuid'} + if version is not None: + schema['version'] = version + + v = SchemaValidator(schema) + if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_python(input_value) From 916d9092fb99d61dea12d17bede18e39fa6719e8 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Thu, 21 Sep 2023 15:50:24 +0100 Subject: 
[PATCH 048/550] Add `validate_core_schema` function and remove validation from `SchemaValidator` and `SchemaSerializer` constructors (#982) Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- benches/main.rs | 8 +++++--- python/pydantic_core/__init__.py | 2 ++ python/pydantic_core/_pydantic_core.pyi | 9 +++++++++ src/lib.rs | 3 ++- src/serializers/mod.rs | 5 +---- src/validators/mod.rs | 9 ++++++--- src/validators/string.rs | 4 ++-- tests/benchmarks/test_complete_benchmark.py | 20 ++++++++++---------- tests/conftest.py | 4 ++-- tests/serializers/test_definitions.py | 4 ++-- tests/serializers/test_dict.py | 6 ++++-- tests/serializers/test_list_tuple.py | 10 ++++++---- tests/serializers/test_misc.py | 4 ++-- tests/test.rs | 4 ++-- tests/test_build.py | 18 +++++++++--------- tests/validators/test_custom_error.py | 4 ++-- tests/validators/test_date.py | 18 ++++++++++-------- tests/validators/test_datetime.py | 10 +++++----- tests/validators/test_decimal.py | 6 ++++-- tests/validators/test_definitions.py | 4 ++-- tests/validators/test_function.py | 8 ++++---- tests/validators/test_model_fields.py | 16 +++++++++------- tests/validators/test_time.py | 7 +++---- tests/validators/test_timedelta.py | 7 +++---- tests/validators/test_typed_dict.py | 18 ++++++++++-------- tests/validators/test_union.py | 4 ++-- 26 files changed, 118 insertions(+), 94 deletions(-) diff --git a/benches/main.rs b/benches/main.rs index d5ab30479..8e020e269 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -7,10 +7,11 @@ use test::{black_box, Bencher}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; -use _pydantic_core::SchemaValidator; +use _pydantic_core::{validate_core_schema, SchemaValidator}; fn build_schema_validator_with_globals(py: Python, code: &str, globals: Option<&PyDict>) -> SchemaValidator { - let schema: &PyDict = py.eval(code, globals, None).unwrap().extract().unwrap(); + let mut schema: &PyDict = py.eval(code, globals, 
None).unwrap().extract().unwrap(); + schema = validate_core_schema(py, schema).unwrap().extract().unwrap(); SchemaValidator::py_new(py, schema, None).unwrap() } @@ -444,7 +445,8 @@ fn complete_model(bench: &mut Bencher) { sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap(); let complete_schema = py.import("complete_schema").unwrap(); - let schema = complete_schema.call_method0("schema").unwrap(); + let mut schema = complete_schema.call_method0("schema").unwrap(); + schema = validate_core_schema(py, schema).unwrap().extract().unwrap(); let validator = SchemaValidator::py_new(py, schema, None).unwrap(); let input = complete_schema.call_method0("input_data_lax").unwrap(); diff --git a/python/pydantic_core/__init__.py b/python/pydantic_core/__init__.py index a916f43d8..a46a77b7d 100644 --- a/python/pydantic_core/__init__.py +++ b/python/pydantic_core/__init__.py @@ -24,6 +24,7 @@ __version__, to_json, to_jsonable_python, + validate_core_schema, ) from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType @@ -63,6 +64,7 @@ 'TzInfo', 'to_json', 'to_jsonable_python', + 'validate_core_schema', ] diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index ca0104b9d..82e6e3015 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -44,6 +44,7 @@ __all__ = [ 'to_jsonable_python', 'list_all_errors', 'TzInfo', + 'validate_core_schema', ] __version__: str build_profile: str @@ -836,3 +837,11 @@ class TzInfo(datetime.tzinfo): def dst(self, _dt: datetime.datetime | None) -> datetime.timedelta: ... def fromutc(self, dt: datetime.datetime) -> datetime.datetime: ... def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ... + +def validate_core_schema(schema: CoreSchema) -> CoreSchema: + """Validate a CoreSchema + This currently uses lax mode for validation (i.e. will coerce strings to dates and such) + but may use strict mode in the future. 
+ We may also remove this function altogether, do not rely on it being present if you are + using pydantic-core directly. + """ diff --git a/src/lib.rs b/src/lib.rs index cbf668b9d..b241cdb8a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,7 @@ pub use errors::{ pub use serializers::{ to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer, }; -pub use validators::{PySome, SchemaValidator}; +pub use validators::{validate_core_schema, PySome, SchemaValidator}; pub fn get_pydantic_core_version() -> &'static str { static PYDANTIC_CORE_VERSION: OnceLock = OnceLock::new(); @@ -97,5 +97,6 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(to_json, m)?)?; m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?; m.add_function(wrap_pyfunction!(list_all_errors, m)?)?; + m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?; Ok(()) } diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 29a54fe05..6dbc076fe 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -7,7 +7,6 @@ use pyo3::{PyTraverseError, PyVisit}; use crate::definitions::DefinitionsBuilder; use crate::py_gc::PyGcTraverse; -use crate::validators::SelfValidator; use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; @@ -73,9 +72,7 @@ impl SchemaSerializer { #[pymethods] impl SchemaSerializer { #[new] - pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult { - let self_validator = SelfValidator::new(py)?; - let schema = self_validator.validate_schema(py, schema)?; + pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; diff --git a/src/validators/mod.rs b/src/validators/mod.rs index de98eb583..c7ce0b9e9 100644 --- 
a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -113,9 +113,6 @@ pub struct SchemaValidator { impl SchemaValidator { #[new] pub fn py_new(py: Python, schema: &PyAny, config: Option<&PyDict>) -> PyResult { - let self_validator = SelfValidator::new(py)?; - let schema = self_validator.validate_schema(py, schema)?; - let mut definitions_builder = DefinitionsBuilder::new(); let mut validator = build_validator(schema, config, &mut definitions_builder)?; @@ -411,6 +408,12 @@ impl<'py> SelfValidator<'py> { } } +#[pyfunction] +pub fn validate_core_schema<'a>(py: Python<'a>, schema: &'a PyAny) -> PyResult<&'a PyAny> { + let self_validator = SelfValidator::new(py)?; + self_validator.validate_schema(py, schema) +} + pub trait BuildValidator: Sized { const EXPECTED_TYPE: &'static str; diff --git a/src/validators/string.rs b/src/validators/string.rs index d997dcf85..6b646224d 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -239,8 +239,8 @@ enum RegexEngine { } impl RegexEngine { - const RUST_REGEX: &str = "rust-regex"; - const PYTHON_RE: &str = "python-re"; + const RUST_REGEX: &'static str = "rust-regex"; + const PYTHON_RE: &'static str = "python-re"; } impl Pattern { diff --git a/tests/benchmarks/test_complete_benchmark.py b/tests/benchmarks/test_complete_benchmark.py index 8ed07d37c..57fb7645c 100644 --- a/tests/benchmarks/test_complete_benchmark.py +++ b/tests/benchmarks/test_complete_benchmark.py @@ -8,7 +8,7 @@ import pytest -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, validate_core_schema from .complete_schema import input_data_lax, input_data_strict, input_data_wrong, schema @@ -16,7 +16,7 @@ def test_complete_valid(): lax_schema = schema() cls = lax_schema['cls'] - lax_validator = SchemaValidator(lax_schema) + lax_validator = SchemaValidator(validate_core_schema(lax_schema)) output = lax_validator.validate_python(input_data_lax()) assert isinstance(output, cls) 
assert len(output.__pydantic_fields_set__) == 41 @@ -73,14 +73,14 @@ def test_complete_valid(): }, } - strict_validator = SchemaValidator(schema(strict=True)) + strict_validator = SchemaValidator(validate_core_schema(schema(strict=True))) output2 = strict_validator.validate_python(input_data_strict()) assert output_dict == output2.__dict__ def test_complete_invalid(): lax_schema = schema() - lax_validator = SchemaValidator(lax_schema) + lax_validator = SchemaValidator(validate_core_schema(lax_schema)) with pytest.raises(ValidationError) as exc_info: lax_validator.validate_python(input_data_wrong()) assert len(exc_info.value.errors(include_url=False)) == 739 @@ -88,19 +88,19 @@ def test_complete_invalid(): @pytest.mark.benchmark(group='complete') def test_complete_core_lax(benchmark): - v = SchemaValidator(schema()) + v = SchemaValidator(validate_core_schema(schema())) benchmark(v.validate_python, input_data_lax()) @pytest.mark.benchmark(group='complete') def test_complete_core_strict(benchmark): - v = SchemaValidator(schema(strict=True)) + v = SchemaValidator(validate_core_schema(schema(strict=True))) benchmark(v.validate_python, input_data_strict()) @pytest.mark.benchmark(group='complete-wrong') def test_complete_core_error(benchmark): - v = SchemaValidator(schema()) + v = SchemaValidator(validate_core_schema(schema())) data = input_data_wrong() @benchmark @@ -115,7 +115,7 @@ def f(): @pytest.mark.benchmark(group='complete-wrong') def test_complete_core_isinstance(benchmark): - v = SchemaValidator(schema()) + v = SchemaValidator(validate_core_schema(schema())) data = input_data_wrong() assert v.isinstance_python(data) is False @@ -135,7 +135,7 @@ def default_json_encoder(obj): @pytest.mark.benchmark(group='complete-json') def test_complete_core_json(benchmark): - v = SchemaValidator(schema()) + v = SchemaValidator(validate_core_schema(schema())) json_data = json.dumps(input_data_lax(), default=default_json_encoder) benchmark(v.validate_json, json_data) @@ -143,4 
+143,4 @@ def test_complete_core_json(benchmark): @pytest.mark.benchmark(group='build') def test_build_schema(benchmark): lax_schema = schema() - benchmark(SchemaValidator, lax_schema) + benchmark(lambda s: SchemaValidator(validate_core_schema(s)), lax_schema) diff --git a/tests/conftest.py b/tests/conftest.py index a83c49472..bdf59a73c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ import pytest from typing_extensions import Literal -from pydantic_core import ArgsKwargs, SchemaValidator, ValidationError +from pydantic_core import ArgsKwargs, SchemaValidator, ValidationError, validate_core_schema from pydantic_core.core_schema import CoreConfig __all__ = 'Err', 'PyAndJson', 'plain_repr', 'infinite_generator' @@ -53,7 +53,7 @@ class PyAndJsonValidator: def __init__( self, schema, config: CoreConfig | None = None, *, validator_type: Literal['json', 'python'] | None = None ): - self.validator = SchemaValidator(schema, config) + self.validator = SchemaValidator(validate_core_schema(schema), config) self.validator_type = validator_type def validate_python(self, py_input, strict: bool | None = None, context: Any = None): diff --git a/tests/serializers/test_definitions.py b/tests/serializers/test_definitions.py index 61b1e9cf7..2da4d353d 100644 --- a/tests/serializers/test_definitions.py +++ b/tests/serializers/test_definitions.py @@ -1,6 +1,6 @@ import pytest -from pydantic_core import SchemaError, SchemaSerializer, core_schema +from pydantic_core import SchemaError, SchemaSerializer, core_schema, validate_core_schema def test_custom_ser(): @@ -25,7 +25,7 @@ def test_ignored_def(): def test_def_error(): with pytest.raises(SchemaError) as exc_info: - SchemaSerializer( + validate_core_schema( core_schema.definitions_schema( core_schema.list_schema(core_schema.definition_reference_schema('foobar')), [core_schema.int_schema(ref='foobar'), {'type': 'wrong'}], diff --git a/tests/serializers/test_dict.py b/tests/serializers/test_dict.py index 
ed627830d..cc1b56091 100644 --- a/tests/serializers/test_dict.py +++ b/tests/serializers/test_dict.py @@ -3,7 +3,7 @@ import pytest from dirty_equals import IsStrictDict -from pydantic_core import SchemaError, SchemaSerializer, core_schema +from pydantic_core import SchemaError, SchemaSerializer, core_schema, validate_core_schema def test_dict_str_int(): @@ -155,4 +155,6 @@ def test_filter_runtime_int(): ) def test_include_error(include_value, error_msg): with pytest.raises(SchemaError, match=error_msg): - SchemaSerializer(core_schema.dict_schema(serialization=core_schema.filter_dict_schema(include=include_value))) + validate_core_schema( + core_schema.dict_schema(serialization=core_schema.filter_dict_schema(include=include_value)) + ) diff --git a/tests/serializers/test_list_tuple.py b/tests/serializers/test_list_tuple.py index 9149941e1..c695bbcc4 100644 --- a/tests/serializers/test_list_tuple.py +++ b/tests/serializers/test_list_tuple.py @@ -4,7 +4,7 @@ import pytest -from pydantic_core import SchemaError, SchemaSerializer, core_schema +from pydantic_core import SchemaError, SchemaSerializer, core_schema, validate_core_schema def test_list_any(): @@ -144,8 +144,10 @@ def test_exclude(schema_func, seq_f): @pytest.mark.parametrize('include,exclude', [({1, 3, 5}, {5, 6}), ([1, 3, 5], [5, 6])]) def test_filter(include, exclude): v = SchemaSerializer( - core_schema.list_schema( - core_schema.any_schema(), serialization=core_schema.filter_seq_schema(include=include, exclude=exclude) + validate_core_schema( + core_schema.list_schema( + core_schema.any_schema(), serialization=core_schema.filter_seq_schema(include=include, exclude=exclude) + ) ) ) assert v.to_python([0, 1, 2, 3, 4, 5, 6, 7]) == [1, 3] @@ -186,7 +188,7 @@ class RemovedContains(ImplicitContains): @pytest.mark.parametrize('schema_func', [core_schema.list_schema, core_schema.tuple_variable_schema]) def test_include_error(schema_func, include_value, error_msg): with pytest.raises(SchemaError, 
match=error_msg): - SchemaSerializer( + validate_core_schema( schema_func(core_schema.any_schema(), serialization=core_schema.filter_seq_schema(include=include_value)) ) diff --git a/tests/serializers/test_misc.py b/tests/serializers/test_misc.py index ec9f97cb3..7753f1948 100644 --- a/tests/serializers/test_misc.py +++ b/tests/serializers/test_misc.py @@ -1,6 +1,6 @@ import pytest -from pydantic_core import SchemaError, SchemaSerializer, core_schema +from pydantic_core import SchemaError, core_schema, validate_core_schema @pytest.mark.parametrize( @@ -12,4 +12,4 @@ ) def test_invalid_ser_schema(ser_schema, msg): with pytest.raises(SchemaError, match=msg): - SchemaSerializer(core_schema.any_schema(serialization=ser_schema)) + validate_core_schema(core_schema.any_schema(serialization=ser_schema)) diff --git a/tests/test.rs b/tests/test.rs index 9b2fb99b5..526b30e5e 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -46,7 +46,7 @@ mod tests { ] }"#; let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap(); - SchemaSerializer::py_new(py, schema, None).unwrap(); + SchemaSerializer::py_new(schema, None).unwrap(); }); } @@ -77,7 +77,7 @@ a = A() py.run(code, None, Some(locals)).unwrap(); let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap(); let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap(); - let serialized: Vec = SchemaSerializer::py_new(py, schema, None) + let serialized: Vec = SchemaSerializer::py_new(schema, None) .unwrap() .to_json(py, a, None, None, None, true, false, false, false, false, true, None) .unwrap() diff --git a/tests/test_build.py b/tests/test_build.py index 095eb6887..b81179de1 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -2,23 +2,23 @@ import pytest -from pydantic_core import SchemaError, SchemaValidator +from pydantic_core import SchemaError, SchemaValidator, validate_core_schema from pydantic_core import core_schema as cs def test_build_error_type(): with 
pytest.raises(SchemaError, match="Input tag 'foobar' found using 'type' does not match any of the"): - SchemaValidator({'type': 'foobar', 'title': 'TestModel'}) + validate_core_schema({'type': 'foobar', 'title': 'TestModel'}) def test_build_error_internal(): with pytest.raises(SchemaError, match='Input should be a valid integer, unable to parse string as an integer'): - SchemaValidator({'type': 'str', 'min_length': 'xxx', 'title': 'TestModel'}) + validate_core_schema({'type': 'str', 'min_length': 'xxx', 'title': 'TestModel'}) def test_build_error_deep(): with pytest.raises(SchemaError, match='Input should be a valid integer, unable to parse string as an integer'): - SchemaValidator( + validate_core_schema( { 'title': 'MyTestModel', 'type': 'typed-dict', @@ -34,7 +34,7 @@ def test_schema_as_string(): def test_schema_wrong_type(pydantic_version): with pytest.raises(SchemaError) as exc_info: - SchemaValidator(1) + validate_core_schema(1) assert str(exc_info.value) == ( 'Invalid Schema:\n Input should be a valid dictionary or object to' ' extract fields from [type=model_attributes_type, input_value=1, input_type=int]\n' @@ -66,7 +66,7 @@ def test_schema_definition_error(): schema = {'type': 'union', 'choices': []} schema['choices'].append({'type': 'nullable', 'schema': schema}) with pytest.raises(SchemaError, match='Recursion error - cyclic reference detected'): - SchemaValidator(schema) + validate_core_schema(schema) def test_not_schema_definition_error(): @@ -83,17 +83,17 @@ def test_not_schema_definition_error(): def test_no_type(): with pytest.raises(SchemaError, match="Unable to extract tag using discriminator 'type'"): - SchemaValidator({}) + validate_core_schema({}) def test_wrong_type(): with pytest.raises(SchemaError, match="Input tag 'unknown' found using 'type' does not match any of the"): - SchemaValidator({'type': 'unknown'}) + validate_core_schema({'type': 'unknown'}) def test_function_no_mode(): with pytest.raises(SchemaError, match="Input tag 'function' 
found using 'type' does not match any of the"): - SchemaValidator({'type': 'function'}) + validate_core_schema({'type': 'function'}) def test_try_self_schema_discriminator(): diff --git a/tests/validators/test_custom_error.py b/tests/validators/test_custom_error.py index 4c787ad0f..f2a6e45b5 100644 --- a/tests/validators/test_custom_error.py +++ b/tests/validators/test_custom_error.py @@ -1,6 +1,6 @@ import pytest -from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import PyAndJson @@ -33,7 +33,7 @@ def test_custom_error_type(py_and_json: PyAndJson): def test_custom_error_error(): with pytest.raises(SchemaError, match=r'custom_error_type\s+Field required \[type=missing'): - SchemaValidator({'type': 'custom-error', 'schema': {'type': 'int'}}) + validate_core_schema({'type': 'custom-error', 'schema': {'type': 'int'}}) def test_custom_error_invalid(): diff --git a/tests/validators/test_date.py b/tests/validators/test_date.py index 616771bf0..5ddde4884 100644 --- a/tests/validators/test_date.py +++ b/tests/validators/test_date.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal @@ -5,7 +7,7 @@ import pytest -from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import Err, PyAndJson @@ -183,9 +185,9 @@ def test_date_strict_json_ctx(): '2000-01-02', Err('Input should be less than or equal to 2000-01-01 [type=less_than_equal,'), ), - ({'lt': '2000-01-01'}, '1999-12-31', date(1999, 12, 31)), - ({'lt': '2000-01-01'}, '2000-01-01', Err('Input should be less than 2000-01-01 [type=less_than,')), - ({'ge': '2000-01-01'}, '2000-01-01', date(2000, 1, 1)), + ({'lt': 
date(2000, 1, 1)}, '1999-12-31', date(1999, 12, 31)), + ({'lt': date(2000, 1, 1)}, '2000-01-01', Err('Input should be less than 2000-01-01 [type=less_than,')), + ({'ge': date(2000, 1, 1)}, '2000-01-01', date(2000, 1, 1)), ( {'ge': date(2000, 1, 1)}, '1999-12-31', @@ -195,8 +197,8 @@ def test_date_strict_json_ctx(): ({'gt': date(2000, 1, 1)}, '2000-01-01', Err('Input should be greater than 2000-01-01 [type=greater_than,')), ], ) -def test_date_kwargs(kwargs: Dict[str, Any], input_value, expected): - v = SchemaValidator({'type': 'date', **kwargs}) +def test_date_kwargs(kwargs: Dict[str, Any], input_value: date, expected: Err | date): + v = SchemaValidator({'type': 'date', **kwargs}) # type: ignore if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_python(input_value) @@ -207,7 +209,7 @@ def test_date_kwargs(kwargs: Dict[str, Any], input_value, expected): def test_invalid_constraint(): with pytest.raises(SchemaError, match=r'date\.gt\n Input should be a valid date or datetime'): - SchemaValidator({'type': 'date', 'gt': 'foobar'}) + validate_core_schema({'type': 'date', 'gt': 'foobar'}) def test_dict_py(): @@ -288,4 +290,4 @@ def test_date_past_future_today(): def test_offset_too_large(): with pytest.raises(SchemaError, match=r'Input should be less than 86400 \[type=less_than,'): - SchemaValidator(core_schema.date_schema(now_op='past', now_utc_offset=24 * 3600)) + validate_core_schema(core_schema.date_schema(now_op='past', now_utc_offset=24 * 3600)) diff --git a/tests/validators/test_datetime.py b/tests/validators/test_datetime.py index 41044e23f..df04d1631 100644 --- a/tests/validators/test_datetime.py +++ b/tests/validators/test_datetime.py @@ -10,7 +10,7 @@ import pytest import pytz -from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import Err, 
PyAndJson @@ -284,7 +284,7 @@ def test_union(): def test_invalid_constraint(): with pytest.raises(SchemaError, match=r'datetime\.gt\n Input should be a valid datetime'): - SchemaValidator({'type': 'datetime', 'gt': 'foobar'}) + validate_core_schema({'type': 'datetime', 'gt': 'foobar'}) @pytest.mark.parametrize( @@ -387,7 +387,7 @@ def test_mock_utc_offset_8_hours(mocker): def test_offset_too_large(): with pytest.raises(SchemaError, match=r'Input should be greater than -86400 \[type=greater_than,'): - SchemaValidator(core_schema.datetime_schema(now_op='past', now_utc_offset=-24 * 3600)) + validate_core_schema(core_schema.datetime_schema(now_op='past', now_utc_offset=-24 * 3600)) def test_raises_schema_error_for_unknown_constraint_kind(): @@ -395,7 +395,7 @@ def test_raises_schema_error_for_unknown_constraint_kind(): SchemaError, match=(r'Input should be \'aware\' or \'naive\' \[type=literal_error, input_value=\'foo\', input_type=str\]'), ): - SchemaValidator({'type': 'datetime', 'tz_constraint': 'foo'}) + validate_core_schema({'type': 'datetime', 'tz_constraint': 'foo'}) def test_aware(): @@ -477,7 +477,7 @@ def test_tz_constraint_too_high(): def test_tz_constraint_wrong(): with pytest.raises(SchemaError, match="Input should be 'aware' or 'naive"): - SchemaValidator(core_schema.datetime_schema(tz_constraint='wrong')) + validate_core_schema(core_schema.datetime_schema(tz_constraint='wrong')) def test_tz_pickle() -> None: diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index d871877c8..376a9816a 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import math import re @@ -184,8 +186,8 @@ def test_decimal_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_va ], ids=repr, ) -def test_decimal_multiple_of(py_and_json: PyAndJson, multiple_of, input_value, error): - v = py_and_json({'type': 'decimal', 'multiple_of': multiple_of}) +def 
test_decimal_multiple_of(py_and_json: PyAndJson, multiple_of: float, input_value: float, error: Err | None): + v = py_and_json({'type': 'decimal', 'multiple_of': Decimal(str(multiple_of))}) if error: with pytest.raises(ValidationError, match=re.escape(error.message)): v.validate_test(input_value) diff --git a/tests/validators/test_definitions.py b/tests/validators/test_definitions.py index b6ac5d133..d742da5dd 100644 --- a/tests/validators/test_definitions.py +++ b/tests/validators/test_definitions.py @@ -1,6 +1,6 @@ import pytest -from pydantic_core import SchemaError, SchemaValidator, core_schema +from pydantic_core import SchemaError, SchemaValidator, core_schema, validate_core_schema from ..conftest import plain_repr @@ -47,7 +47,7 @@ def test_check_ref_used_ignores_metadata(): def test_def_error(): with pytest.raises(SchemaError) as exc_info: - SchemaValidator( + validate_core_schema( core_schema.definitions_schema( core_schema.list_schema(core_schema.definition_reference_schema('foobar')), [core_schema.int_schema(ref='foobar'), {'type': 'wrong'}], diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 1a63c27a3..1b35a9c25 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -7,7 +7,7 @@ import pytest from dirty_equals import HasRepr -from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import plain_repr @@ -223,12 +223,12 @@ def f(input_value, validator, info): def test_function_wrap_not_callable(): with pytest.raises(SchemaError, match='function-wrap.function.typed-dict.function\n Input should be callable'): - SchemaValidator( + validate_core_schema( {'type': 'function-wrap', 'function': {'type': 'general', 'function': []}, 'schema': {'type': 'str'}} ) with pytest.raises(SchemaError, match='function-wrap.function\n Field required'): - 
SchemaValidator({'type': 'function-wrap', 'schema': {'type': 'str'}}) + validate_core_schema({'type': 'function-wrap', 'schema': {'type': 'str'}}) def test_wrap_error(): @@ -450,7 +450,7 @@ def f(input_value): def test_plain_with_schema(): with pytest.raises(SchemaError, match='function-plain.schema\n Extra inputs are not permitted'): - SchemaValidator( + validate_core_schema( { 'type': 'function-plain', 'function': {'type': 'general', 'function': lambda x: x}, diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index dbba463e2..05201759a 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -8,7 +8,7 @@ import pytest from dirty_equals import FunctionCheck, HasRepr, IsStr -from pydantic_core import CoreConfig, SchemaError, SchemaValidator, ValidationError, core_schema +from pydantic_core import CoreConfig, SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import Err, PyAndJson @@ -430,7 +430,7 @@ def test_json_error(): def test_missing_schema_key(): with pytest.raises(SchemaError, match='model-fields.fields.x.schema\n Field required'): - SchemaValidator({'type': 'model-fields', 'fields': {'x': {'type': 'str'}}}) + validate_core_schema({'type': 'model-fields', 'fields': {'x': {'type': 'str'}}}) def test_fields_required_by_default(): @@ -734,10 +734,12 @@ def test_paths_allow_by_name(py_and_json: PyAndJson, input_value): def test_alias_build_error(alias_schema, error): with pytest.raises(SchemaError, match=error): SchemaValidator( - { - 'type': 'model-fields', - 'fields': {'field_a': {'type': 'model-field', 'schema': {'type': 'int'}, **alias_schema}}, - } + validate_core_schema( + { + 'type': 'model-fields', + 'fields': {'field_a': {'type': 'model-field', 'schema': {'type': 'int'}, **alias_schema}}, + } + ) ) @@ -1491,7 +1493,7 @@ def test_bad_default_factory(default_factory, error_message): class TestOnError: def 
test_on_error_bad_name(self): with pytest.raises(SchemaError, match="Input should be 'raise', 'omit' or 'default'"): - SchemaValidator( + validate_core_schema( { 'type': 'model-fields', 'fields': { diff --git a/tests/validators/test_time.py b/tests/validators/test_time.py index 29c45dc80..360cda644 100644 --- a/tests/validators/test_time.py +++ b/tests/validators/test_time.py @@ -5,7 +5,7 @@ import pytest -from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import Err, PyAndJson @@ -156,7 +156,6 @@ def test_time_strict_json(input_value, expected): ({'ge': time(1)}, '00:59', Err('Input should be greater than or equal to 01:00:00')), ({'gt': time(12, 13, 14, 123_456)}, '12:13:14.123457', time(12, 13, 14, 123_457)), ({'gt': time(12, 13, 14, 123_456)}, '12:13:14.123456', Err('Input should be greater than 12:13:14.123456')), - ({'gt': '12:13:14.123456'}, '12:13:14.123456', Err('Input should be greater than 12:13:14.123456')), ], ) def test_time_kwargs(kwargs: Dict[str, Any], input_value, expected): @@ -192,7 +191,7 @@ def test_time_bound_ctx(): def test_invalid_constraint(): with pytest.raises(SchemaError, match='Input should be in a valid time format'): - SchemaValidator({'type': 'time', 'gt': 'foobar'}) + validate_core_schema({'type': 'time', 'gt': 'foobar'}) def test_dict_py(): @@ -294,4 +293,4 @@ def test_tz_constraint_too_high(): def test_tz_constraint_wrong(): with pytest.raises(SchemaError, match="Input should be 'aware' or 'naive"): - SchemaValidator(core_schema.time_schema(tz_constraint='wrong')) + validate_core_schema(core_schema.time_schema(tz_constraint='wrong')) diff --git a/tests/validators/test_timedelta.py b/tests/validators/test_timedelta.py index 2bceebf85..3af2a0857 100644 --- a/tests/validators/test_timedelta.py +++ b/tests/validators/test_timedelta.py @@ -5,7 +5,7 @@ import pytest -from 
pydantic_core import SchemaError, SchemaValidator, ValidationError +from pydantic_core import SchemaError, SchemaValidator, ValidationError, validate_core_schema from ..conftest import Err, PyAndJson @@ -143,7 +143,6 @@ def test_timedelta_strict_json(input_value, expected): ({'ge': timedelta(days=3)}, 'P3D', timedelta(days=3)), ({'ge': timedelta(days=3)}, 'P2DT1H', Err('Input should be greater than or equal to 3 days')), ({'gt': timedelta(days=3)}, 'P3DT1H', timedelta(days=3, hours=1)), - ({'gt': 'P3D'}, 'P2DT1H', Err('Input should be greater than 3 days')), ({'le': timedelta(seconds=-86400.123)}, '-PT86400.123S', timedelta(seconds=-86400.123)), ({'le': timedelta(seconds=-86400.123)}, '-PT86400.124S', timedelta(seconds=-86400.124)), ( @@ -197,10 +196,10 @@ def test_timedelta_kwargs_strict(): def test_invalid_constraint(): with pytest.raises(SchemaError, match='timedelta.gt\n Input should be a valid timedelta, invalid digit in'): - SchemaValidator({'type': 'timedelta', 'gt': 'foobar'}) + validate_core_schema({'type': 'timedelta', 'gt': 'foobar'}) with pytest.raises(SchemaError, match='timedelta.le\n Input should be a valid timedelta, invalid digit in'): - SchemaValidator({'type': 'timedelta', 'le': 'foobar'}) + validate_core_schema({'type': 'timedelta', 'le': 'foobar'}) def test_dict_py(): diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 1d3b694f1..67cf97d97 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -8,7 +8,7 @@ import pytest from dirty_equals import FunctionCheck -from pydantic_core import CoreConfig, SchemaError, SchemaValidator, ValidationError, core_schema +from pydantic_core import CoreConfig, SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import Err, PyAndJson @@ -194,7 +194,7 @@ def test_allow_extra_invalid(): def test_allow_extra_wrong(): with pytest.raises(SchemaError, match="Input should be 'allow', 'forbid' or 
'ignore'"): - SchemaValidator({'type': 'typed-dict', 'fields': {}, 'config': {'extra_fields_behavior': 'wrong'}}) + validate_core_schema({'type': 'typed-dict', 'fields': {}, 'config': {'extra_fields_behavior': 'wrong'}}) def test_str_config(): @@ -235,7 +235,7 @@ def test_json_error(): def test_missing_schema_key(): with pytest.raises(SchemaError, match='typed-dict.fields.x.schema\n Field required'): - SchemaValidator({'type': 'typed-dict', 'fields': {'x': {'type': 'str'}}}) + validate_core_schema({'type': 'typed-dict', 'fields': {'x': {'type': 'str'}}}) def test_fields_required_by_default(): @@ -629,10 +629,12 @@ def test_paths_allow_by_name(py_and_json: PyAndJson, input_value): def test_alias_build_error(alias_schema, error): with pytest.raises(SchemaError, match=error): SchemaValidator( - { - 'type': 'typed-dict', - 'fields': {'field_a': {'type': 'typed-dict-field', 'schema': {'type': 'int'}, **alias_schema}}, - } + validate_core_schema( + { + 'type': 'typed-dict', + 'fields': {'field_a': {'type': 'typed-dict-field', 'schema': {'type': 'int'}, **alias_schema}}, + } + ) ) @@ -899,7 +901,7 @@ def test_bad_default_factory(default_factory, error_message): class TestOnError: def test_on_error_bad_name(self): with pytest.raises(SchemaError, match="Input should be 'raise', 'omit' or 'default'"): - SchemaValidator( + validate_core_schema( { 'type': 'typed-dict', 'fields': { diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 05072b806..ad51fb447 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1,7 +1,7 @@ import pytest from dirty_equals import IsFloat, IsInt -from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema from ..conftest import plain_repr @@ -234,7 +234,7 @@ def test_union_list_bool_int(): def test_no_choices(pydantic_version): with pytest.raises(SchemaError) 
as exc_info: - SchemaValidator({'type': 'union'}) + validate_core_schema({'type': 'union'}) assert str(exc_info.value) == ( 'Invalid Schema:\n' From 0a3f2b218c588bc8b9757f055dc9f2e255155b52 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 21 Sep 2023 10:03:39 -0500 Subject: [PATCH 049/550] add strict flag to validate_core_schema (#984) --- benches/main.rs | 4 ++-- python/pydantic_core/_pydantic_core.pyi | 2 +- src/validators/mod.rs | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/benches/main.rs b/benches/main.rs index 8e020e269..9d46131d1 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -11,7 +11,7 @@ use _pydantic_core::{validate_core_schema, SchemaValidator}; fn build_schema_validator_with_globals(py: Python, code: &str, globals: Option<&PyDict>) -> SchemaValidator { let mut schema: &PyDict = py.eval(code, globals, None).unwrap().extract().unwrap(); - schema = validate_core_schema(py, schema).unwrap().extract().unwrap(); + schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); SchemaValidator::py_new(py, schema, None).unwrap() } @@ -446,7 +446,7 @@ fn complete_model(bench: &mut Bencher) { let complete_schema = py.import("complete_schema").unwrap(); let mut schema = complete_schema.call_method0("schema").unwrap(); - schema = validate_core_schema(py, schema).unwrap().extract().unwrap(); + schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); let validator = SchemaValidator::py_new(py, schema, None).unwrap(); let input = complete_schema.call_method0("input_data_lax").unwrap(); diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 82e6e3015..8ed3092a9 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -838,7 +838,7 @@ class TzInfo(datetime.tzinfo): def fromutc(self, dt: datetime.datetime) -> datetime.datetime: ... 
def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ... -def validate_core_schema(schema: CoreSchema) -> CoreSchema: +def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> CoreSchema: """Validate a CoreSchema This currently uses lax mode for validation (i.e. will coerce strings to dates and such) but may use strict mode in the future. diff --git a/src/validators/mod.rs b/src/validators/mod.rs index c7ce0b9e9..4ee677663 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -367,10 +367,10 @@ impl<'py> SelfValidator<'py> { Ok(Self { validator }) } - pub fn validate_schema(&self, py: Python<'py>, schema: &'py PyAny) -> PyResult<&'py PyAny> { + pub fn validate_schema(&self, py: Python<'py>, schema: &'py PyAny, strict: Option) -> PyResult<&'py PyAny> { let mut recursion_guard = RecursionGuard::default(); let mut state = ValidationState::new( - Extra::new(None, None, None, None, InputType::Python), + Extra::new(strict, None, None, None, InputType::Python), &self.validator.definitions, &mut recursion_guard, ); @@ -408,10 +408,10 @@ impl<'py> SelfValidator<'py> { } } -#[pyfunction] -pub fn validate_core_schema<'a>(py: Python<'a>, schema: &'a PyAny) -> PyResult<&'a PyAny> { +#[pyfunction(signature = (schema, *, strict = None))] +pub fn validate_core_schema<'a>(py: Python<'a>, schema: &'a PyAny, strict: Option) -> PyResult<&'a PyAny> { let self_validator = SelfValidator::new(py)?; - self_validator.validate_schema(py, schema) + self_validator.validate_schema(py, schema, strict) } pub trait BuildValidator: Sized { From 91746c9565c43739c7a8f392209bb6d1aeb06d8a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 21 Sep 2023 17:18:39 +0100 Subject: [PATCH 050/550] make `field_name` and `data` available on `ValidationInfo` (#980) --- Cargo.lock | 2 +- Cargo.toml | 2 +- python/pydantic_core/core_schema.py | 407 +++++++----------- src/validators/function.rs | 69 +-- tests/benchmarks/complete_schema.py | 8 +- 
tests/benchmarks/test_micro_benchmarks.py | 29 +- tests/serializers/test_functions.py | 4 +- tests/serializers/test_other.py | 10 +- tests/test_errors.py | 52 +-- tests/test_isinstance.py | 2 +- tests/test_misc.py | 10 + tests/test_schema_functions.py | 20 +- tests/test_typing.py | 8 +- tests/test_validation_context.py | 56 +-- tests/validators/test_arguments.py | 2 +- tests/validators/test_chain.py | 30 +- tests/validators/test_dataclasses.py | 71 +-- .../validators/test_definitions_recursive.py | 6 +- tests/validators/test_function.py | 196 ++++----- tests/validators/test_list.py | 12 +- tests/validators/test_model.py | 44 +- tests/validators/test_model_fields.py | 6 +- tests/validators/test_model_init.py | 24 +- tests/validators/test_model_root.py | 10 +- tests/validators/test_nullable.py | 2 +- tests/validators/test_typed_dict.py | 4 +- tests/validators/test_with_default.py | 2 +- 27 files changed, 445 insertions(+), 643 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1027b8597..d70598ad8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,7 +242,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.9.0" +version = "2.10.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 9e193b3d4..154d310de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.9.0" +version = "2.10.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 718bdf969..2d7061ffd 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -6,11 +6,14 @@ from __future__ import annotations as _annotations import sys +import warnings from collections.abc import Mapping from datetime import date, datetime, time, timedelta from decimal import Decimal from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union 
+from typing_extensions import deprecated + if sys.version_info < (3, 12): from typing_extensions import TypedDict else: @@ -177,19 +180,13 @@ def mode(self) -> Literal['python', 'json']: """The type of input data we are currently validating""" ... - -class FieldValidationInfo(ValidationInfo, Protocol): - """ - Argument passed to model field validation functions. - """ - @property def data(self) -> Dict[str, Any]: - """All of the fields and data being validated for this model.""" + """The data being validated for this model.""" ... @property - def field_name(self) -> str: + def field_name(self) -> str | None: """ The name of the current field being validated if this validator is attached to a model field. @@ -1744,25 +1741,16 @@ class NoInfoValidatorFunctionSchema(TypedDict): # (__input_value: Any, __info: ValidationInfo) -> Any -GeneralValidatorFunction = Callable[[Any, ValidationInfo], Any] - +WithInfoValidatorFunction = Callable[[Any, ValidationInfo], Any] -class GeneralValidatorFunctionSchema(TypedDict): - type: Literal['general'] - function: GeneralValidatorFunction - -# (__input_value: Any, __info: FieldValidationInfo) -> Any -FieldValidatorFunction = Callable[[Any, FieldValidationInfo], Any] - - -class FieldValidatorFunctionSchema(TypedDict): - type: Literal['field'] - function: FieldValidatorFunction +class WithInfoValidatorFunctionSchema(TypedDict, total=False): + type: Required[Literal['with-info']] + function: Required[WithInfoValidatorFunction] field_name: str -ValidationFunction = Union[NoInfoValidatorFunctionSchema, FieldValidatorFunctionSchema, GeneralValidatorFunctionSchema] +ValidationFunction = Union[NoInfoValidatorFunctionSchema, WithInfoValidatorFunctionSchema] class _ValidatorFunctionSchema(TypedDict, total=False): @@ -1786,7 +1774,7 @@ def no_info_before_validator_function( serialization: SerSchema | None = None, ) -> BeforeValidatorFunctionSchema: """ - Returns a schema that calls a validator function before validating, no info is provided, 
e.g.: + Returns a schema that calls a validator function before validating, no `info` argument is provided, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -1820,29 +1808,29 @@ def fn(v: bytes) -> str: ) -def field_before_validator_function( - function: FieldValidatorFunction, - field_name: str, +def with_info_before_validator_function( + function: WithInfoValidatorFunction, schema: CoreSchema, *, + field_name: str | None = None, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None, ) -> BeforeValidatorFunctionSchema: """ - Returns a schema that calls a validator function before validating the function is called with information - about the field being validated, e.g.: + Returns a schema that calls a validator function before validation, the function is called with + an `info` argument, e.g.: ```py from pydantic_core import SchemaValidator, core_schema - def fn(v: bytes, info: core_schema.FieldValidationInfo) -> str: + def fn(v: bytes, info: core_schema.ValidationInfo) -> str: assert info.data is not None assert info.field_name is not None return v.decode() + 'world' - func_schema = core_schema.field_before_validator_function( - function=fn, field_name='a', schema=core_schema.str_schema() + func_schema = core_schema.with_info_before_validator_function( + function=fn, schema=core_schema.str_schema(), field_name='a' ) schema = core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(func_schema)}) @@ -1860,51 +1848,7 @@ def fn(v: bytes, info: core_schema.FieldValidationInfo) -> str: """ return _dict_not_none( type='function-before', - function={'type': 'field', 'function': function, 'field_name': field_name}, - schema=schema, - ref=ref, - metadata=metadata, - serialization=serialization, - ) - - -def general_before_validator_function( - function: GeneralValidatorFunction, - schema: CoreSchema, - *, - ref: str | None = None, - metadata: Any = None, - serialization: SerSchema | None = None, -) -> 
BeforeValidatorFunctionSchema: - """ - Returns a schema that calls a validator function before validating the provided schema, e.g.: - - ```py - from typing import Any - from pydantic_core import SchemaValidator, core_schema - - def fn(v: Any, info: core_schema.ValidationInfo) -> str: - v_str = str(v) - assert 'hello' in v_str - return v_str + 'world' - - schema = core_schema.general_before_validator_function( - function=fn, schema=core_schema.str_schema() - ) - v = SchemaValidator(schema) - assert v.validate_python(b'hello ') == "b'hello 'world" - ``` - - Args: - function: The validator function to call - schema: The schema to validate the output of the validator function - ref: optional unique identifier of the schema, used to reference the schema in other places - metadata: Any other information you want to include with the schema, not used by pydantic-core - serialization: Custom serialization schema - """ - return _dict_not_none( - type='function-before', - function={'type': 'general', 'function': function}, + function=_dict_not_none(type='with-info', function=function, field_name=field_name), schema=schema, ref=ref, metadata=metadata, @@ -1925,7 +1869,7 @@ def no_info_after_validator_function( serialization: SerSchema | None = None, ) -> AfterValidatorFunctionSchema: """ - Returns a schema that calls a validator function after validating, no info is provided, e.g.: + Returns a schema that calls a validator function after validating, no `info` argument is provided, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -1957,29 +1901,29 @@ def fn(v: str) -> str: ) -def field_after_validator_function( - function: FieldValidatorFunction, - field_name: str, +def with_info_after_validator_function( + function: WithInfoValidatorFunction, schema: CoreSchema, *, + field_name: str | None = None, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None, ) -> AfterValidatorFunctionSchema: """ - Returns a schema that calls a 
validator function after validating the function is called with information - about the field being validated, e.g.: + Returns a schema that calls a validator function after validation, the function is called with + an `info` argument, e.g.: ```py from pydantic_core import SchemaValidator, core_schema - def fn(v: str, info: core_schema.FieldValidationInfo) -> str: + def fn(v: str, info: core_schema.ValidationInfo) -> str: assert info.data is not None assert info.field_name is not None return v + 'world' - func_schema = core_schema.field_after_validator_function( - function=fn, field_name='a', schema=core_schema.str_schema() + func_schema = core_schema.with_info_after_validator_function( + function=fn, schema=core_schema.str_schema(), field_name='a' ) schema = core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(func_schema)}) @@ -1989,57 +1933,15 @@ def fn(v: str, info: core_schema.FieldValidationInfo) -> str: Args: function: The validator function to call after the schema is validated - field_name: The name of the field - schema: The schema to validate before the validator function - ref: optional unique identifier of the schema, used to reference the schema in other places - metadata: Any other information you want to include with the schema, not used by pydantic-core - serialization: Custom serialization schema - """ - return _dict_not_none( - type='function-after', - function={'type': 'field', 'function': function, 'field_name': field_name}, - schema=schema, - ref=ref, - metadata=metadata, - serialization=serialization, - ) - - -def general_after_validator_function( - function: GeneralValidatorFunction, - schema: CoreSchema, - *, - ref: str | None = None, - metadata: Any = None, - serialization: SerSchema | None = None, -) -> AfterValidatorFunctionSchema: - """ - Returns a schema that calls a validator function after validating the provided schema, e.g.: - - ```py - from pydantic_core import SchemaValidator, core_schema - - def fn(v: str, info: 
core_schema.ValidationInfo) -> str: - assert 'hello' in v - return v + 'world' - - schema = core_schema.general_after_validator_function( - schema=core_schema.str_schema(), function=fn - ) - v = SchemaValidator(schema) - assert v.validate_python('hello ') == 'hello world' - ``` - - Args: schema: The schema to validate before the validator function - function: The validator function to call after the schema is validated + field_name: The name of the field this validators is applied to, if any ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ return _dict_not_none( type='function-after', - function={'type': 'general', 'function': function}, + function=_dict_not_none(type='with-info', function=function, field_name=field_name), schema=schema, ref=ref, metadata=metadata, @@ -2062,27 +1964,16 @@ class NoInfoWrapValidatorFunctionSchema(TypedDict): # (__input_value: Any, __validator: ValidatorFunctionWrapHandler, __info: ValidationInfo) -> Any -GeneralWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler, ValidationInfo], Any] - - -class GeneralWrapValidatorFunctionSchema(TypedDict): - type: Literal['general'] - function: GeneralWrapValidatorFunction +WithInfoWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler, ValidationInfo], Any] -# (__input_value: Any, __validator: ValidatorFunctionWrapHandler, __info: FieldValidationInfo) -> Any -FieldWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler, FieldValidationInfo], Any] - - -class FieldWrapValidatorFunctionSchema(TypedDict): - type: Literal['field'] - function: FieldWrapValidatorFunction +class WithInfoWrapValidatorFunctionSchema(TypedDict, total=False): + type: Required[Literal['with-info']] + function: Required[WithInfoWrapValidatorFunction] field_name: str -WrapValidatorFunction = Union[ - 
NoInfoWrapValidatorFunctionSchema, GeneralWrapValidatorFunctionSchema, FieldWrapValidatorFunctionSchema -] +WrapValidatorFunction = Union[NoInfoWrapValidatorFunctionSchema, WithInfoWrapValidatorFunctionSchema] class WrapValidatorFunctionSchema(TypedDict, total=False): @@ -2105,7 +1996,7 @@ def no_info_wrap_validator_function( """ Returns a schema which calls a function with a `validator` callable argument which can optionally be used to call inner validation with the function logic, this is much like the - "onion" implementation of middleware in many popular web frameworks, no info argument is passed, e.g.: + "onion" implementation of middleware in many popular web frameworks, no `info` argument is passed, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -2140,10 +2031,11 @@ def fn( ) -def general_wrap_validator_function( - function: GeneralWrapValidatorFunction, +def with_info_wrap_validator_function( + function: WithInfoWrapValidatorFunction, schema: CoreSchema, *, + field_name: str | None = None, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None, @@ -2151,7 +2043,7 @@ def general_wrap_validator_function( """ Returns a schema which calls a function with a `validator` callable argument which can optionally be used to call inner validation with the function logic, this is much like the - "onion" implementation of middleware in many popular web frameworks, general info is also passed, e.g.: + "onion" implementation of middleware in many popular web frameworks, an `info` argument is also passed, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -2163,7 +2055,7 @@ def fn( ) -> str: return validator(input_value=v) + 'world' - schema = core_schema.general_wrap_validator_function( + schema = core_schema.with_info_wrap_validator_function( function=fn, schema=core_schema.str_schema() ) v = SchemaValidator(schema) @@ -2173,67 +2065,14 @@ def fn( Args: function: The validator function to call schema: 
The schema to validate the output of the validator function + field_name: The name of the field this validators is applied to, if any ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ return _dict_not_none( type='function-wrap', - function={'type': 'general', 'function': function}, - schema=schema, - ref=ref, - metadata=metadata, - serialization=serialization, - ) - - -def field_wrap_validator_function( - function: FieldWrapValidatorFunction, - field_name: str, - schema: CoreSchema, - *, - ref: str | None = None, - metadata: Any = None, - serialization: SerSchema | None = None, -) -> WrapValidatorFunctionSchema: - """ - Returns a schema applicable to **fields** - which calls a function with a `validator` callable argument which can - optionally be used to call inner validation with the function logic, this is much like the - "onion" implementation of middleware in many popular web frameworks, field info is passed, e.g.: - - ```py - from pydantic_core import SchemaValidator, core_schema - - def fn( - v: bytes, - validator: core_schema.ValidatorFunctionWrapHandler, - info: core_schema.FieldValidationInfo, - ) -> str: - assert info.data is not None - assert info.field_name is not None - return validator(v) + 'world' - - func_schema = core_schema.field_wrap_validator_function( - function=fn, field_name='a', schema=core_schema.str_schema() - ) - schema = core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(func_schema)}) - - v = SchemaValidator(schema) - assert v.validate_python({'a': b'hello '}) == {'a': 'hello world'} - ``` - - Args: - function: The validator function to call - field_name: The name of the field - schema: The schema to validate the output of the validator function - ref: optional unique identifier of the schema, used to reference the schema in other places - 
metadata: Any other information you want to include with the schema, not used by pydantic-core - serialization: Custom serialization schema - """ - return _dict_not_none( - type='function-wrap', - function={'type': 'field', 'function': function, 'field_name': field_name}, + function=_dict_not_none(type='with-info', function=function, field_name=field_name), schema=schema, ref=ref, metadata=metadata, @@ -2257,7 +2096,7 @@ def no_info_plain_validator_function( serialization: SerSchema | None = None, ) -> PlainValidatorFunctionSchema: """ - Returns a schema that uses the provided function for validation, no info is passed, e.g.: + Returns a schema that uses the provided function for validation, no `info` argument is passed, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -2286,15 +2125,16 @@ def fn(v: str) -> str: ) -def general_plain_validator_function( - function: GeneralValidatorFunction, +def with_info_plain_validator_function( + function: WithInfoValidatorFunction, *, + field_name: str | None = None, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None, ) -> PlainValidatorFunctionSchema: """ - Returns a schema that uses the provided function for validation, e.g.: + Returns a schema that uses the provided function for validation, an `info` argument is passed, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -2303,63 +2143,21 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str: assert 'hello' in v return v + 'world' - schema = core_schema.general_plain_validator_function(function=fn) + schema = core_schema.with_info_plain_validator_function(function=fn) v = SchemaValidator(schema) assert v.validate_python('hello ') == 'hello world' ``` Args: function: The validator function to call + field_name: The name of the field this validators is applied to, if any ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want 
to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ return _dict_not_none( type='function-plain', - function={'type': 'general', 'function': function}, - ref=ref, - metadata=metadata, - serialization=serialization, - ) - - -def field_plain_validator_function( - function: FieldValidatorFunction, - field_name: str, - *, - ref: str | None = None, - metadata: Any = None, - serialization: SerSchema | None = None, -) -> PlainValidatorFunctionSchema: - """ - Returns a schema that uses the provided function for validation, e.g.: - - ```py - from typing import Any - from pydantic_core import SchemaValidator, core_schema - - def fn(v: Any, info: core_schema.FieldValidationInfo) -> str: - assert info.data is not None - assert info.field_name is not None - return str(v) + 'world' - - func_schema = core_schema.field_plain_validator_function(function=fn, field_name='a') - schema = core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(func_schema)}) - - v = SchemaValidator(schema) - assert v.validate_python({'a': 'hello '}) == {'a': 'hello world'} - ``` - - Args: - function: The validator function to call - field_name: The name of the field - ref: optional unique identifier of the schema, used to reference the schema in other places - metadata: Any other information you want to include with the schema, not used by pydantic-core - serialization: Custom serialization schema - """ - return _dict_not_none( - type='function-plain', - function={'type': 'field', 'function': function, 'field_name': field_name}, + function=_dict_not_none(type='with-info', function=function, field_name=field_name), ref=ref, metadata=metadata, serialization=serialization, @@ -2659,7 +2457,7 @@ def fn(v: str, info: core_schema.ValidationInfo) -> str: assert 'hello' in v return v + ' world' - fn_schema = core_schema.general_plain_validator_function(function=fn) + fn_schema = core_schema.with_info_plain_validator_function(function=fn) schema = 
core_schema.chain_schema( [fn_schema, fn_schema, fn_schema, core_schema.str_schema()] ) @@ -4023,3 +3821,102 @@ def definition_reference_schema( def _dict_not_none(**kwargs: Any) -> Any: return {k: v for k, v in kwargs.items() if v is not None} + + +############################################################################### +# All this stuff is deprecated by #980 and will be removed eventually +# They're kept because some external code will be using them + + +@deprecated('`field_before_validator_function` is deprecated, use `with_info_before_validator_function` instead.') +def field_before_validator_function(function: WithInfoValidatorFunction, field_name: str, schema: CoreSchema, **kwargs): + warnings.warn( + '`field_before_validator_function` is deprecated, use `with_info_before_validator_function` instead.', + DeprecationWarning, + ) + return with_info_before_validator_function(function, schema, field_name=field_name, **kwargs) + + +@deprecated('`general_before_validator_function` is deprecated, use `with_info_before_validator_function` instead.') +def general_before_validator_function(*args, **kwargs): + warnings.warn( + '`general_before_validator_function` is deprecated, use `with_info_before_validator_function` instead.', + DeprecationWarning, + ) + return with_info_before_validator_function(*args, **kwargs) + + +@deprecated('`field_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.') +def field_after_validator_function(function: WithInfoValidatorFunction, field_name: str, schema: CoreSchema, **kwargs): + warnings.warn( + '`field_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.', + DeprecationWarning, + ) + return with_info_after_validator_function(function, schema, field_name=field_name, **kwargs) + + +@deprecated('`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.') +def general_after_validator_function(*args,
**kwargs): + warnings.warn( + '`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.', + DeprecationWarning, + ) + return with_info_after_validator_function(*args, **kwargs) + + +@deprecated('`field_wrap_validator_function` is deprecated, use `with_info_wrap_validator_function` instead.') +def field_wrap_validator_function( + function: WithInfoWrapValidatorFunction, field_name: str, schema: CoreSchema, **kwargs +): + warnings.warn( + '`field_wrap_validator_function` is deprecated, use `with_info_wrap_validator_function` instead.', + DeprecationWarning, + ) + return with_info_wrap_validator_function(function, schema, field_name=field_name, **kwargs) + + +@deprecated('`general_wrap_validator_function` is deprecated, use `with_info_wrap_validator_function` instead.') +def general_wrap_validator_function(*args, **kwargs): + warnings.warn( + '`general_wrap_validator_function` is deprecated, use `with_info_wrap_validator_function` instead.', + DeprecationWarning, + ) + return with_info_wrap_validator_function(*args, **kwargs) + + +@deprecated('`field_plain_validator_function` is deprecated, use `with_info_plain_validator_function` instead.') +def field_plain_validator_function(function: WithInfoValidatorFunction, field_name: str, **kwargs): + warnings.warn( + '`field_plain_validator_function` is deprecated, use `with_info_plain_validator_function` instead.', + DeprecationWarning, + ) + return with_info_plain_validator_function(function, field_name=field_name, **kwargs) + + +@deprecated('`general_plain_validator_function` is deprecated, use `with_info_plain_validator_function` instead.') +def general_plain_validator_function(*args, **kwargs): + warnings.warn( + '`general_plain_validator_function` is deprecated, use `with_info_plain_validator_function` instead.', + DeprecationWarning, + ) + return with_info_plain_validator_function(*args, **kwargs) + + +_deprecated_import_lookup = { + 'FieldValidationInfo': ValidationInfo, +
'FieldValidatorFunction': WithInfoValidatorFunction, + 'GeneralValidatorFunction': WithInfoValidatorFunction, + 'FieldWrapValidatorFunction': WithInfoWrapValidatorFunction, +} + + +def __getattr__(attr_name: str) -> object: + new_attr = _deprecated_import_lookup.get(attr_name) + if new_attr is None: + raise AttributeError(f"module 'pydantic_core' has no attribute '{attr_name}'") + else: + import warnings + + msg = f'`{attr_name}` is deprecated, use `{new_attr.__name__}` instead.' + warnings.warn(msg, DeprecationWarning, stacklevel=1) + return new_attr diff --git a/src/validators/function.rs b/src/validators/function.rs index 6334a5e16..be0d6374f 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -1,4 +1,4 @@ -use pyo3::exceptions::{PyAssertionError, PyAttributeError, PyTypeError, PyValueError}; +use pyo3::exceptions::{PyAssertionError, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; @@ -8,7 +8,7 @@ use crate::errors::{ }; use crate::input::Input; use crate::py_gc::PyGcTraverse; -use crate::tools::{function_name, py_err, SchemaDict}; +use crate::tools::{function_name, py_err, safe_repr, SchemaDict}; use crate::PydanticUseDefault; use super::generator::InternalValidator; @@ -28,20 +28,14 @@ fn destructure_function_schema(schema: &PyDict) -> PyResult { let func_dict: &PyDict = schema.get_as_req(intern!(schema.py(), "function"))?; let function: &PyAny = func_dict.get_as_req(intern!(schema.py(), "function"))?; let func_type: &str = func_dict.get_as_req(intern!(schema.py(), "type"))?; - let (is_field_validator, info_arg) = match func_type { - "field" => (true, true), - "general" => (false, true), - "no-info" => (false, false), + let info_arg = match func_type { + "with-info" => true, + "no-info" => false, _ => unreachable!(), }; - let field_name: Option> = match is_field_validator { - true => Some( - func_dict - 
.get_as_req::<&PyString>(intern!(schema.py(), "field_name"))? - .into(), - ), - false => None, - }; + let field_name = func_dict + .get_as::<&PyString>(intern!(schema.py(), "field_name"))? + .map(Into::into); Ok(FunctionInfo { function: function.into(), field_name, @@ -525,15 +519,12 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> } } -#[pyclass(module = "pydantic_core._pydantic_core")] +#[pyclass(module = "pydantic_core._pydantic_core", get_all)] pub struct ValidationInfo { - #[pyo3(get)] config: PyObject, - #[pyo3(get)] context: Option, data: Option>, field_name: Option>, - #[pyo3(get)] mode: InputType, } @@ -551,40 +542,22 @@ impl ValidationInfo { #[pymethods] impl ValidationInfo { - #[getter] - fn get_data(&self, py: Python) -> PyResult> { - match (&self.data, &self.field_name) { - (Some(data), Some(_)) => Ok(data.clone_ref(py)), - _ => Err(PyAttributeError::new_err("No attribute named 'data'")), - } - } - - #[getter] - fn get_field_name(&self) -> PyResult> { - match self.field_name { - Some(ref field_name) => Ok(field_name.clone()), - None => Err(PyAttributeError::new_err("No attribute named 'field_name'")), - } - } - fn __repr__(&self, py: Python) -> PyResult { let context = match self.context { - Some(ref context) => context.as_ref(py).repr()?.extract()?, - None => "None", + Some(ref context) => safe_repr(context.as_ref(py)), + None => "None".into(), }; let config = self.config.as_ref(py).repr()?; - let mut s = if self.field_name.is_some() { - format!("FieldValidationInfo(config={config}, context={context}") - } else { - format!("ValidationInfo(config={config}, context={context}") + let data = match self.data { + Some(ref data) => safe_repr(data.as_ref(py)), + None => "None".into(), }; - if let Ok(data) = self.get_data(py) { - s += &format!(", data={}", data.as_ref(py).repr()?); - } - if let Ok(field_name) = self.get_field_name() { - s += &format!(", field_name='{field_name}'"); - } - s += ")"; - Ok(s) + let field_name = 
match self.field_name { + Some(ref field_name) => safe_repr(field_name.as_ref(py)), + None => "None".into(), + }; + Ok(format!( + "ValidationInfo(config={config}, context={context}, data={data}, field_name={field_name})" + )) } } diff --git a/tests/benchmarks/complete_schema.py b/tests/benchmarks/complete_schema.py index d4eff16b2..8c24b9b7e 100644 --- a/tests/benchmarks/complete_schema.py +++ b/tests/benchmarks/complete_schema.py @@ -158,7 +158,7 @@ def wrap_function(input_value, validator, info): 'type': 'typed-dict-field', 'schema': { 'type': 'function-before', - 'function': {'type': 'general', 'function': append_func}, + 'function': {'type': 'with-info', 'function': append_func}, 'schema': {'type': 'str'}, }, }, @@ -166,7 +166,7 @@ def wrap_function(input_value, validator, info): 'type': 'typed-dict-field', 'schema': { 'type': 'function-after', - 'function': {'type': 'general', 'function': append_func}, + 'function': {'type': 'with-info', 'function': append_func}, 'schema': {'type': 'str'}, }, }, @@ -174,7 +174,7 @@ def wrap_function(input_value, validator, info): 'type': 'typed-dict-field', 'schema': { 'type': 'function-wrap', - 'function': {'type': 'general', 'function': wrap_function}, + 'function': {'type': 'with-info', 'function': wrap_function}, 'schema': {'type': 'str'}, }, }, @@ -182,7 +182,7 @@ def wrap_function(input_value, validator, info): 'type': 'typed-dict-field', 'schema': { 'type': 'function-plain', - 'function': {'type': 'general', 'function': append_func}, + 'function': {'type': 'with-info', 'function': append_func}, }, }, }, diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 7e6799a74..cb08436c6 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -791,7 +791,7 @@ def test_dont_raise_error(benchmark): def f(input_value, info): return input_value - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'general', 'function': 
f}}) + v = SchemaValidator(core_schema.with_info_plain_validator_function(f)) @benchmark def t(): @@ -803,7 +803,7 @@ def test_dont_raise_error_no_info(benchmark): def f(input_value): return input_value - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'no-info', 'function': f}}) + v = SchemaValidator(core_schema.no_info_plain_validator_function(f)) @benchmark def t(): @@ -815,7 +815,7 @@ def test_raise_error_value_error(benchmark): def f(input_value, info): raise ValueError('this is a custom error') - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'general', 'function': f}}) + v = SchemaValidator(core_schema.with_info_plain_validator_function(f)) @benchmark def t(): @@ -832,7 +832,7 @@ def test_raise_error_custom(benchmark): def f(input_value, info): raise PydanticCustomError('my_error', 'this is a custom error {foo}', {'foo': 'FOOBAR'}) - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'general', 'function': f}}) + v = SchemaValidator(core_schema.with_info_plain_validator_function(f)) @benchmark def t(): @@ -927,10 +927,7 @@ def test_chain_list(benchmark): validator = SchemaValidator( { 'type': 'chain', - 'steps': [ - {'type': 'str'}, - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: Decimal(v)}}, - ], + 'steps': [{'type': 'str'}, core_schema.with_info_plain_validator_function(lambda v, info: Decimal(v))], } ) assert validator.validate_python('42.42') == Decimal('42.42') @@ -944,7 +941,7 @@ def test_chain_function(benchmark): { 'type': 'function-after', 'schema': {'type': 'str'}, - 'function': {'type': 'general', 'function': lambda v, info: Decimal(v)}, + 'function': {'type': 'with-info', 'function': lambda v, info: Decimal(v)}, } ) assert validator.validate_python('42.42') == Decimal('42.42') @@ -959,8 +956,8 @@ def test_chain_two_functions(benchmark): 'type': 'chain', 'steps': [ {'type': 'str'}, - {'type': 'function-plain', 'function': {'type': 'general', 
'function': lambda v, info: Decimal(v)}}, - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: v * 2}}, + core_schema.with_info_plain_validator_function(lambda v, info: Decimal(v)), + core_schema.with_info_plain_validator_function(lambda v, info: v * 2), ], } ) @@ -977,9 +974,9 @@ def test_chain_nested_functions(benchmark): 'schema': { 'type': 'function-after', 'schema': {'type': 'str'}, - 'function': {'type': 'general', 'function': lambda v, info: Decimal(v)}, + 'function': {'type': 'with-info', 'function': lambda v, info: Decimal(v)}, }, - 'function': {'type': 'general', 'function': lambda v, info: v * 2}, + 'function': {'type': 'with-info', 'function': lambda v, info: v * 2}, } ) assert validator.validate_python('42.42') == Decimal('84.84') @@ -1002,7 +999,7 @@ def generator_gen_python(v, validator, info): @pytest.mark.benchmark(group='generator') def test_generator_python(benchmark): - schema = core_schema.general_wrap_validator_function(generator_gen_python, {'type': 'int'}) + schema = core_schema.with_info_wrap_validator_function(generator_gen_python, {'type': 'int'}) v = SchemaValidator(schema) input_value = tuple(range(100)) @@ -1301,7 +1298,7 @@ def test_tagged_union_int_keys_json(benchmark): @skip_wasm_deep_stack @pytest.mark.benchmark(group='field_function_validator') def test_field_function_validator(benchmark) -> None: - def f(v: int, info: core_schema.FieldValidationInfo) -> int: + def f(v: int, info: core_schema.ValidationInfo) -> int: assert info.field_name == 'x' return v + 1 @@ -1309,7 +1306,7 @@ def f(v: int, info: core_schema.FieldValidationInfo) -> int: limit = pydantic_core._pydantic_core._recursion_limit - 3 for _ in range(limit): - schema = core_schema.field_after_validator_function(f, 'x', schema) + schema = core_schema.with_info_after_validator_function(f, schema, field_name='x') schema = core_schema.typed_dict_schema({'x': core_schema.typed_dict_field(schema)}) diff --git 
a/tests/serializers/test_functions.py b/tests/serializers/test_functions.py index 660b5e2f2..318254602 100644 --- a/tests/serializers/test_functions.py +++ b/tests/serializers/test_functions.py @@ -588,7 +588,7 @@ def test_function_after_preserves_wrapped_serialization(): def f(value, _info): return value - s = SchemaSerializer(core_schema.general_after_validator_function(f, core_schema.int_schema())) + s = SchemaSerializer(core_schema.with_info_after_validator_function(f, core_schema.int_schema())) with pytest.warns(UserWarning, match='Expected `int` but got `str` - serialized value may not be as expected'): assert s.to_python('abc') == 'abc' @@ -597,7 +597,7 @@ def test_function_wrap_preserves_wrapped_serialization(): def f(value, handler, _info): return handler(value) - s = SchemaSerializer(core_schema.general_wrap_validator_function(f, core_schema.int_schema())) + s = SchemaSerializer(core_schema.with_info_wrap_validator_function(f, core_schema.int_schema())) with pytest.warns(UserWarning, match='Expected `int` but got `str` - serialized value may not be as expected'): assert s.to_python('abc') == 'abc' diff --git a/tests/serializers/test_other.py b/tests/serializers/test_other.py index 06183d773..42549b0b5 100644 --- a/tests/serializers/test_other.py +++ b/tests/serializers/test_other.py @@ -16,20 +16,24 @@ def test_chain(): def test_function_plain(): - s = SchemaSerializer(core_schema.general_plain_validator_function(lambda v, info: v + 1)) + s = SchemaSerializer(core_schema.with_info_plain_validator_function(lambda v, info: v + 1)) # can't infer the type from plain function validators # insert_assert(plain_repr(s)) assert plain_repr(s) == 'SchemaSerializer(serializer=Any(AnySerializer),definitions=[])' def test_function_before(): - s = SchemaSerializer(core_schema.general_before_validator_function(lambda v, info: v + 1, core_schema.int_schema())) + s = SchemaSerializer( + core_schema.with_info_before_validator_function(lambda v, info: v + 1, 
core_schema.int_schema()) + ) # insert_assert(plain_repr(s)) assert plain_repr(s) == 'SchemaSerializer(serializer=Int(IntSerializer),definitions=[])' def test_function_after(): - s = SchemaSerializer(core_schema.general_after_validator_function(lambda v, info: v + 1, core_schema.int_schema())) + s = SchemaSerializer( + core_schema.with_info_after_validator_function(lambda v, info: v + 1, core_schema.int_schema()) + ) # insert_assert(plain_repr(s)) assert plain_repr(s) == 'SchemaSerializer(serializer=Int(IntSerializer),definitions=[])' diff --git a/tests/test_errors.py b/tests/test_errors.py index 15d2c78af..293880977 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -64,7 +64,7 @@ def test_pydantic_value_error_usage(): def f(input_value, info): raise PydanticCustomError('my_error', 'this is a custom error {foo} {bar}', {'foo': 'FOOBAR', 'bar': 42}) - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'general', 'function': f}}) + v = SchemaValidator(core_schema.with_info_plain_validator_function(f)) with pytest.raises(ValidationError) as exc_info: v.validate_python(42) @@ -84,7 +84,7 @@ def test_pydantic_value_error_invalid_dict(): def my_function(input_value, info): raise PydanticCustomError('my_error', 'this is a custom error {foo}', {(): 'foobar'}) - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'general', 'function': my_function}}) + v = SchemaValidator(core_schema.with_info_plain_validator_function(my_function)) with pytest.raises(ValidationError) as exc_info: v.validate_python(42) @@ -102,7 +102,7 @@ def test_pydantic_value_error_invalid_type(): def f(input_value, info): raise PydanticCustomError('my_error', 'this is a custom error {foo}', [('foo', 123)]) - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'general', 'function': f}}) + v = SchemaValidator(core_schema.with_info_plain_validator_function(f)) with pytest.raises(TypeError, match="argument 'context': 'list' object cannot be 
converted to 'PyDict'"): v.validate_python(42) @@ -118,9 +118,7 @@ def validate(self, input_value, info): return f'{input_value} {self.foo} {self.bar}' c = CustomValidator() - v = SchemaValidator( - {'type': 'function-plain', 'metadata': {'instance': c}, 'function': {'type': 'general', 'function': c.validate}} - ) + v = SchemaValidator(core_schema.with_info_plain_validator_function(c.validate, metadata={'instance': c})) c.foo += 1 assert v.validate_python('input value') == 'input value 43 before' @@ -139,12 +137,7 @@ def validate(self, input_value, info): c = CustomValidator() v = SchemaValidator( - { - 'type': 'function-after', - 'metadata': {'instance': c}, - 'function': {'type': 'general', 'function': c.validate}, - 'schema': {'type': 'str'}, - } + core_schema.with_info_after_validator_function(c.validate, core_schema.str_schema(), metadata={'instance': c}) ) c.foo += 1 @@ -175,9 +168,7 @@ def test_pydantic_error_type_raise_no_ctx(): def f(input_value, info): raise PydanticKnownError('finite_number') - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'int'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.int_schema())) with pytest.raises(ValidationError) as exc_info: v.validate_python(4) @@ -196,9 +187,7 @@ def test_pydantic_error_type_raise_ctx(extra: dict): def f(input_value, info): raise PydanticKnownError('greater_than', ctx) - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'int'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.int_schema())) with pytest.raises(ValidationError) as exc_info: v.validate_python(4) @@ -213,9 +202,7 @@ def test_pydantic_error_type_raise_custom_no_ctx(ctx: Optional[dict]): def f(input_value, info): raise PydanticKnownError('int_type', ctx) - v = SchemaValidator( - {'type': 'function-before', 'function': 
{'type': 'general', 'function': f}, 'schema': {'type': 'int'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.int_schema())) expect_ctx = {'ctx': {}} if ctx is not None else {} @@ -236,9 +223,7 @@ def test_pydantic_custom_error_type_raise_custom_ctx(extra: dict): def f(input_value, info): raise PydanticCustomError('my_error', 'my message with {val}', ctx) - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'int'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.int_schema())) with pytest.raises(ValidationError) as exc_info: v.validate_python(4) @@ -253,9 +238,7 @@ def test_pydantic_custom_error_type_raise_custom_no_ctx(ctx: Optional[dict]): def f(input_value, info): raise PydanticCustomError('my_error', 'my message', ctx) - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'int'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.int_schema())) expect_ctx = {'ctx': {}} if ctx is not None else {} @@ -422,28 +405,19 @@ def test_pydantic_value_error_plain(py_and_json: PyAndJson): def f(input_value, info): raise PydanticCustomError - v = py_and_json({'type': 'function-plain', 'function': {'type': 'general', 'function': f}}) + v = py_and_json(core_schema.with_info_plain_validator_function(f)) with pytest.raises(TypeError, match='missing 2 required positional arguments'): v.validate_test('4') @pytest.mark.parametrize('exception', [PydanticOmit(), PydanticOmit]) def test_list_omit_exception(py_and_json: PyAndJson, exception): - def f(input_value, info): + def f(input_value): if input_value % 2 == 0: raise exception return input_value - v = py_and_json( - { - 'type': 'list', - 'items_schema': { - 'type': 'function-after', - 'schema': {'type': 'int'}, - 'function': {'type': 'general', 'function': f}, - }, - } - ) + v = 
py_and_json(core_schema.list_schema(core_schema.no_info_after_validator_function(f, core_schema.int_schema()))) assert v.validate_test([1, 2, '3', '4']) == [1, 3] diff --git a/tests/test_isinstance.py b/tests/test_isinstance.py index 44c3f2390..38b6f0160 100644 --- a/tests/test_isinstance.py +++ b/tests/test_isinstance.py @@ -56,7 +56,7 @@ def omit(v, info): else: return v - v = py_and_json(core_schema.general_plain_validator_function(omit)) + v = py_and_json(core_schema.with_info_plain_validator_function(omit)) assert v.validate_test('foo') == 'foo' if v.validator_type == 'python': assert v.isinstance_test('foo') is True diff --git a/tests/test_misc.py b/tests/test_misc.py index 98aac82da..806c4cca9 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -204,3 +204,13 @@ def test_unicode_error_input_repr() -> None: actual = repr(exc_info.value).split('For further information visit ')[0].strip() assert expected == actual + + +def test_core_schema_import_field_validation_info(): + with pytest.warns(DeprecationWarning, match='`FieldValidationInfo` is deprecated, use `ValidationInfo` instead.'): + core_schema.FieldValidationInfo + + +def test_core_schema_import_missing(): + with pytest.raises(AttributeError, match="module 'pydantic_core' has no attribute 'foobar'"): + core_schema.foobar diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index c119ca567..d4b53cbe4 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -108,32 +108,36 @@ def args(*args, **kwargs): {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': {'type': 'int'}}, ), ( - core_schema.general_before_validator_function, + core_schema.with_info_before_validator_function, args(val_function, {'type': 'int'}), { 'type': 'function-before', - 'function': {'type': 'general', 'function': val_function}, + 'function': {'type': 'with-info', 'function': val_function}, 'schema': {'type': 'int'}, }, ), ( - 
core_schema.general_after_validator_function, + core_schema.with_info_after_validator_function, args(val_function, {'type': 'int'}), { 'type': 'function-after', - 'function': {'type': 'general', 'function': val_function}, + 'function': {'type': 'with-info', 'function': val_function}, 'schema': {'type': 'int'}, }, ), ( - core_schema.general_wrap_validator_function, + core_schema.with_info_wrap_validator_function, args(val_function, {'type': 'int'}), - {'type': 'function-wrap', 'function': {'type': 'general', 'function': val_function}, 'schema': {'type': 'int'}}, + { + 'type': 'function-wrap', + 'function': {'type': 'with-info', 'function': val_function}, + 'schema': {'type': 'int'}, + }, ), ( - core_schema.general_plain_validator_function, + core_schema.with_info_plain_validator_function, args(val_function), - {'type': 'function-plain', 'function': {'type': 'general', 'function': val_function}}, + core_schema.with_info_plain_validator_function(val_function), ), ( core_schema.with_default_schema, diff --git a/tests/test_typing.py b/tests/test_typing.py index 55c3732e8..0d527c619 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -111,11 +111,11 @@ def test_schema_typing() -> None: SchemaValidator(schema) schema: CoreSchema = { 'type': 'function-wrap', - 'function': {'type': 'general', 'function': wrap_validator}, + 'function': {'type': 'with-info', 'function': wrap_validator, 'field_name': 'foobar'}, 'schema': {'type': 'str'}, } SchemaValidator(schema) - schema: CoreSchema = {'type': 'function-plain', 'function': {'type': 'general', 'function': validator}} + schema: CoreSchema = core_schema.with_info_plain_validator_function(validator) SchemaValidator(schema) schema: CoreSchema = { 'type': 'definitions', @@ -189,7 +189,7 @@ def test_correct_function_signature() -> None: def my_validator(value: Any, info: Any) -> str: return str(value) - v = SchemaValidator(core_schema.general_plain_validator_function(my_validator)) + v = 
SchemaValidator(core_schema.with_info_plain_validator_function(my_validator)) assert v.validate_python(1) == '1' @@ -197,7 +197,7 @@ def test_wrong_function_signature() -> None: def wrong_validator(value: Any) -> Any: return value - v = SchemaValidator(core_schema.general_plain_validator_function(wrong_validator)) # type: ignore + v = SchemaValidator(core_schema.with_info_plain_validator_function(wrong_validator)) # type: ignore # use this instead of pytest.raises since pyright complains about input when pytest isn't installed try: diff --git a/tests/test_validation_context.py b/tests/test_validation_context.py index 838bd66e2..5d71f5b55 100644 --- a/tests/test_validation_context.py +++ b/tests/test_validation_context.py @@ -1,6 +1,6 @@ import pytest -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from .conftest import PyAndJson @@ -9,9 +9,7 @@ def test_after(py_and_json: PyAndJson): def f(input_value, info): return input_value + f'| context: {info.context}' - v = py_and_json( - {'type': 'function-after', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = py_and_json(core_schema.with_info_after_validator_function(f, core_schema.str_schema())) assert v.validate_test('foobar') == 'foobar| context: None' assert v.validate_test('foobar', None, {1: 10}) == 'foobar| context: {1: 10}' @@ -23,9 +21,7 @@ def f(input_value, info): info.context['foo'] = input_value return input_value - v = py_and_json( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = py_and_json(core_schema.with_info_before_validator_function(f, core_schema.str_schema())) mutable_context = {} assert v.validate_test('foobar', None, mutable_context) == 'foobar' assert mutable_context == {'foo': 'foobar'} @@ -41,19 +37,12 @@ def f2(input_value, info): return input_value + f'| context: {info.context}' v = py_and_json( - { - 'type': 
'typed-dict', - 'fields': { - 'f1': { - 'type': 'typed-dict-field', - 'schema': {'type': 'function-plain', 'function': {'type': 'general', 'function': f1}}, - }, - 'f2': { - 'type': 'typed-dict-field', - 'schema': {'type': 'function-plain', 'function': {'type': 'general', 'function': f2}}, - }, - }, - } + core_schema.typed_dict_schema( + { + 'f1': core_schema.typed_dict_field(core_schema.with_info_plain_validator_function(f1)), + 'f2': core_schema.typed_dict_field(core_schema.with_info_plain_validator_function(f2)), + } + ) ) assert v.validate_test({'f1': '1', 'f2': '2'}, None, {'x': 'y'}) == { @@ -66,9 +55,7 @@ def test_wrap(py_and_json: PyAndJson): def f(input_value, validator, info): return validator(input_value) + f'| context: {info.context}' - v = py_and_json( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = py_and_json(core_schema.with_info_wrap_validator_function(f, core_schema.str_schema())) assert v.validate_test('foobar') == 'foobar| context: None' assert v.validate_test('foobar', None, {1: 10}) == 'foobar| context: {1: 10}' @@ -81,9 +68,7 @@ def f(input_value, validator, info): raise ValueError('wrong') return validator(input_value) - v = py_and_json( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = py_and_json(core_schema.with_info_wrap_validator_function(f, core_schema.str_schema())) assert v.validate_python('foobar', None, {}) == 'foobar' @@ -115,19 +100,12 @@ def f2(input_value, info): return input_value + f'| context: {info.context}' v = SchemaValidator( - { - 'type': 'model-fields', - 'fields': { - 'f1': { - 'type': 'model-field', - 'schema': {'type': 'function-plain', 'function': {'type': 'general', 'function': f1}}, - }, - 'f2': { - 'type': 'model-field', - 'schema': {'type': 'function-plain', 'function': {'type': 'general', 'function': f2}}, - }, - }, - } + core_schema.model_fields_schema( + { + 'f1': 
core_schema.model_field(core_schema.with_info_plain_validator_function(f1)), + 'f2': core_schema.model_field(core_schema.with_info_plain_validator_function(f2)), + } + ) ) m1, model_extra, fields_set = v.validate_python({'f1': '1', 'f2': '2'}, strict=None, context={'x': 'y'}) diff --git a/tests/validators/test_arguments.py b/tests/validators/test_arguments.py index 0f634eb54..4ef581b47 100644 --- a/tests/validators/test_arguments.py +++ b/tests/validators/test_arguments.py @@ -621,7 +621,7 @@ def test_internal_error(py_and_json: PyAndJson): { 'name': 'b', 'mode': 'positional_only', - 'schema': {'type': 'function-plain', 'function': {'type': 'general', 'function': double_or_bust}}, + 'schema': core_schema.with_info_plain_validator_function(double_or_bust), }, ], } diff --git a/tests/validators/test_chain.py b/tests/validators/test_chain.py index 36561b735..0ba59084d 100644 --- a/tests/validators/test_chain.py +++ b/tests/validators/test_chain.py @@ -2,7 +2,7 @@ import pytest -from pydantic_core import SchemaError, SchemaValidator, ValidationError +from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema from ..conftest import PyAndJson @@ -11,10 +11,7 @@ def test_chain(): validator = SchemaValidator( { 'type': 'chain', - 'steps': [ - {'type': 'str'}, - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: Decimal(v)}}, - ], + 'steps': [{'type': 'str'}, core_schema.with_info_plain_validator_function(lambda v, info: Decimal(v))], } ) @@ -27,10 +24,10 @@ def test_chain_many(): { 'type': 'chain', 'steps': [ - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: f'{v}-1'}}, - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: f'{v}-2'}}, - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: f'{v}-3'}}, - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: f'{v}-4'}}, + 
core_schema.with_info_plain_validator_function(lambda v, info: f'{v}-1'), + core_schema.with_info_plain_validator_function(lambda v, info: f'{v}-2'), + core_schema.with_info_plain_validator_function(lambda v, info: f'{v}-3'), + core_schema.with_info_plain_validator_function(lambda v, info: f'{v}-4'), ], } ) @@ -66,7 +63,7 @@ def test_json(py_and_json: PyAndJson, input_value, expected): 'type': 'chain', 'steps': [ {'type': 'union', 'choices': [{'type': 'str'}, {'type': 'float'}]}, - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: Decimal(v)}}, + core_schema.with_info_plain_validator_function(lambda v, info: Decimal(v)), ], } ) @@ -80,17 +77,17 @@ def test_flatten(): { 'type': 'chain', 'steps': [ - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: f'{v}-1'}}, + core_schema.with_info_plain_validator_function(lambda v, info: f'{v}-1'), { 'type': 'chain', 'steps': [ { 'type': 'function-plain', - 'function': {'type': 'general', 'function': lambda v, info: f'{v}-2'}, + 'function': {'type': 'with-info', 'function': lambda v, info: f'{v}-2'}, }, { 'type': 'function-plain', - 'function': {'type': 'general', 'function': lambda v, info: f'{v}-3'}, + 'function': {'type': 'with-info', 'function': lambda v, info: f'{v}-3'}, }, ], }, @@ -109,12 +106,7 @@ def test_chain_empty(): def test_chain_one(): validator = SchemaValidator( - { - 'type': 'chain', - 'steps': [ - {'type': 'function-plain', 'function': {'type': 'general', 'function': lambda v, info: f'{v}-1'}} - ], - } + {'type': 'chain', 'steps': [core_schema.with_info_plain_validator_function(lambda v, info: f'{v}-1')]} ) assert validator.validate_python('input') == 'input-1' assert validator.title == 'function-plain[()]' diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 895f4f2bd..9c2b93dfe 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -493,7 +493,7 @@ class 
Foo: b: str @classmethod - def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: + def validate_b(cls, v: str, info: core_schema.ValidationInfo) -> str: assert v == 'hello' assert info.field_name == 'b' assert info.data == {'a': 1} @@ -507,7 +507,9 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), core_schema.dataclass_field( name='b', - schema=core_schema.field_after_validator_function(Foo.validate_b, 'b', core_schema.str_schema()), + schema=core_schema.with_info_after_validator_function( + Foo.validate_b, core_schema.str_schema(), field_name='b' + ), ), ], ), @@ -526,7 +528,7 @@ class Foo: b: str @classmethod - def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> str: + def validate_b(cls, v: bytes, info: core_schema.ValidationInfo) -> str: assert v == b'hello' assert info.field_name == 'b' assert info.data == {'a': 1} @@ -539,7 +541,7 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> str: [ core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), core_schema.dataclass_field( - name='b', schema=core_schema.field_plain_validator_function(Foo.validate_b, 'b') + name='b', schema=core_schema.with_info_plain_validator_function(Foo.validate_b, field_name='b') ), ], ), @@ -558,7 +560,7 @@ class Foo: b: str @classmethod - def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: + def validate_b(cls, v: bytes, info: core_schema.ValidationInfo) -> bytes: assert v == b'hello' assert info.field_name == 'b' assert info.data == {'a': 1} @@ -572,7 +574,9 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), core_schema.dataclass_field( name='b', - schema=core_schema.field_before_validator_function(Foo.validate_b, 'b', core_schema.str_schema()), + 
schema=core_schema.with_info_before_validator_function( + Foo.validate_b, core_schema.str_schema(), field_name='b' + ), ), ], ), @@ -592,7 +596,7 @@ class Foo: @classmethod def validate_b( - cls, v: bytes, nxt: core_schema.ValidatorFunctionWrapHandler, info: core_schema.FieldValidationInfo + cls, v: bytes, nxt: core_schema.ValidatorFunctionWrapHandler, info: core_schema.ValidationInfo ) -> str: assert v == b'hello' v = nxt(v) @@ -609,7 +613,9 @@ def validate_b( core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), core_schema.dataclass_field( name='b', - schema=core_schema.field_wrap_validator_function(Foo.validate_b, 'b', core_schema.str_schema()), + schema=core_schema.with_info_wrap_validator_function( + Foo.validate_b, core_schema.str_schema(), field_name='b' + ), ), ], ), @@ -629,7 +635,7 @@ class Foo: @classmethod def validate_b( - cls, v: bytes, nxt: core_schema.ValidatorFunctionWrapHandler, info: core_schema.FieldValidationInfo + cls, v: bytes, nxt: core_schema.ValidatorFunctionWrapHandler, info: core_schema.ValidationInfo ) -> bytes: assert v == b'hello' assert info.field_name == 'b' @@ -644,7 +650,9 @@ def validate_b( core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), core_schema.dataclass_field( name='b', - schema=core_schema.field_wrap_validator_function(Foo.validate_b, 'b', core_schema.str_schema()), + schema=core_schema.with_info_wrap_validator_function( + Foo.validate_b, core_schema.str_schema(), field_name='b' + ), ), ], ), @@ -870,7 +878,10 @@ def func(x, info): [ core_schema.dataclass_field('field_a', core_schema.str_schema()), core_schema.dataclass_field( - 'field_b', core_schema.field_after_validator_function(func, 'field_b', core_schema.int_schema()) + 'field_b', + core_schema.with_info_after_validator_function( + func, core_schema.int_schema(), field_name='field_b' + ), ), core_schema.dataclass_field('field_c', core_schema.int_schema()), ], @@ -883,14 +894,14 @@ def func(x, info): assert m.field_a == 'x' 
assert m.field_b == 246 assert m.field_c == 456 - assert calls == ["FieldValidationInfo(config=None, context=None, data={'field_a': 'x'}, field_name='field_b')"] + assert calls == ["ValidationInfo(config=None, context=None, data={'field_a': 'x'}, field_name='field_b')"] v.validate_assignment(m, 'field_b', '111') assert m.field_b == 222 assert calls == [ - "FieldValidationInfo(config=None, context=None, data={'field_a': 'x'}, field_name='field_b')", - "FieldValidationInfo(config=None, context=None, data={'field_a': 'x', 'field_c': 456}, field_name='field_b')", + "ValidationInfo(config=None, context=None, data={'field_a': 'x'}, field_name='field_b')", + "ValidationInfo(config=None, context=None, data={'field_a': 'x', 'field_c': 456}, field_name='field_b')", ] @@ -1271,7 +1282,7 @@ class Foo: b: str @classmethod - def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: + def validate_b(cls, v: bytes, info: core_schema.ValidationInfo) -> bytes: assert v == b'hello' assert info.field_name == 'b' assert info.data == {'a': 1} @@ -1285,7 +1296,9 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), core_schema.dataclass_field( name='b', - schema=core_schema.field_before_validator_function(Foo.validate_b, 'b', core_schema.str_schema()), + schema=core_schema.with_info_before_validator_function( + Foo.validate_b, core_schema.str_schema(), field_name='b' + ), ), ], ), @@ -1306,7 +1319,7 @@ class Foo: b: str @classmethod - def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: + def validate_b(cls, v: str, info: core_schema.ValidationInfo) -> str: assert v == 'hello' assert info.field_name == 'b' assert info.data == {'a': 1} @@ -1320,7 +1333,9 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), core_schema.dataclass_field( name='b', - 
schema=core_schema.field_after_validator_function(Foo.validate_b, 'b', core_schema.str_schema()), + schema=core_schema.with_info_after_validator_function( + Foo.validate_b, core_schema.str_schema(), field_name='b' + ), ), ], ), @@ -1515,9 +1530,15 @@ def _wrap_validator(cls, v, validator, info): field_schema = core_schema.int_schema() if validator == 'field': - field_schema = core_schema.field_before_validator_function(Dataclass._validator, 'a', field_schema) - field_schema = core_schema.field_wrap_validator_function(Dataclass._wrap_validator, 'a', field_schema) - field_schema = core_schema.field_after_validator_function(Dataclass._validator, 'a', field_schema) + field_schema = core_schema.with_info_before_validator_function( + Dataclass._validator, field_schema, field_name='a' + ) + field_schema = core_schema.with_info_wrap_validator_function( + Dataclass._wrap_validator, field_schema, field_name='a' + ) + field_schema = core_schema.with_info_after_validator_function( + Dataclass._validator, field_schema, field_name='a' + ) dataclass_schema = core_schema.dataclass_schema( Dataclass, @@ -1526,9 +1547,11 @@ def _wrap_validator(cls, v, validator, info): ) if validator == 'dataclass': - dataclass_schema = core_schema.general_before_validator_function(Dataclass._validator, dataclass_schema) - dataclass_schema = core_schema.general_wrap_validator_function(Dataclass._wrap_validator, dataclass_schema) - dataclass_schema = core_schema.general_after_validator_function(Dataclass._validator, dataclass_schema) + dataclass_schema = core_schema.with_info_before_validator_function(Dataclass._validator, dataclass_schema) + dataclass_schema = core_schema.with_info_wrap_validator_function( + Dataclass._wrap_validator, dataclass_schema + ) + dataclass_schema = core_schema.with_info_after_validator_function(Dataclass._validator, dataclass_schema) # If any of the Rust validators don't implement traversal properly, # there will be an undetectable cycle created by this assignment diff 
--git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index f23999cfa..b836eb7a1 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -519,7 +519,7 @@ def wrap_func(input_value, validator, info): core_schema.definitions_schema( core_schema.definition_reference_schema('wrapper'), [ - core_schema.general_wrap_validator_function( + core_schema.with_info_wrap_validator_function( wrap_func, core_schema.tuple_positional_schema( [ @@ -653,7 +653,7 @@ def f(input_value, info): [ core_schema.union_schema( [ - core_schema.general_after_validator_function( + core_schema.with_info_after_validator_function( f, core_schema.definition_reference_schema('root-schema') ), core_schema.int_schema(), @@ -701,7 +701,7 @@ def f(input_value, info): [ core_schema.union_schema( [ - core_schema.general_before_validator_function( + core_schema.with_info_before_validator_function( f, core_schema.definition_reference_schema('root-schema') ) ], diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 1b35a9c25..9f94ceb1b 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -2,7 +2,7 @@ import platform import re from copy import deepcopy -from typing import Any, Dict, List, Type, Union +from typing import Any, Dict, List, Type import pytest from dirty_equals import HasRepr @@ -12,11 +12,11 @@ from ..conftest import plain_repr -def deepcopy_info(info: Union[core_schema.ValidationInfo, core_schema.FieldValidationInfo]) -> Dict[str, Any]: +def deepcopy_info(info: core_schema.ValidationInfo) -> Dict[str, Any]: return { 'context': deepcopy(info.context), - 'data': deepcopy(getattr(info, 'data', None)), - 'field_name': deepcopy(getattr(info, 'field_name', None)), + 'data': deepcopy(info.data), + 'field_name': deepcopy(info.field_name), 'config': deepcopy(info.config), } @@ -25,9 +25,7 @@ def test_function_before(): def 
f(input_value, _info): return input_value + ' Changed' - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.str_schema())) assert v.validate_python('input value') == 'input value Changed' @@ -36,9 +34,7 @@ def test_function_before_no_info(): def f(input_value): return input_value + ' Changed' - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'no-info', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.no_info_before_validator_function(f, core_schema.str_schema())) assert v.validate_python('input value') == 'input value Changed' @@ -47,9 +43,7 @@ def test_function_before_raise(): def f(input_value, info): raise ValueError('foobar') - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.str_schema())) with pytest.raises(ValidationError) as exc_info: assert v.validate_python('input value') == 'input value Changed' @@ -72,7 +66,7 @@ def my_function(input_value, info): v = SchemaValidator( { 'type': 'function-before', - 'function': {'type': 'general', 'function': my_function}, + 'function': {'type': 'with-info', 'function': my_function}, 'schema': {'type': 'str', 'max_length': 5}, } ) @@ -107,7 +101,7 @@ def my_function(input_value, info): v = SchemaValidator( { 'type': 'function-before', - 'function': {'type': 'general', 'function': my_function}, + 'function': {'type': 'with-info', 'function': my_function}, 'schema': {'type': 'str', 'max_length': 5}, }, config, @@ -126,7 +120,7 @@ def f(input_value, info): v = SchemaValidator( { 'type': 'function-before', - 'function': {'type': 'general', 'function': f}, + 'function': {'type': 'with-info', 'function': f}, 'schema': { 'type': 'typed-dict', 
'fields': {'my_field': {'type': 'typed-dict-field', 'schema': {'type': 'str', 'max_length': 5}}}, @@ -151,10 +145,10 @@ def f(input_value, info): @pytest.mark.parametrize( 'config,kwargs,expected_repr', [ - (None, {}, 'ValidationInfo(config=None, context=None)'), - (None, {'context': {1: 2}}, 'ValidationInfo(config=None, context={1: 2})'), - (None, {'context': None}, 'ValidationInfo(config=None, context=None)'), - ({'title': 'hello'}, {}, "ValidationInfo(config={'title': 'hello'}, context=None)"), + (None, {}, 'ValidationInfo(config=None, context=None, data=None, field_name=None)'), + (None, {'context': {1: 2}}, 'ValidationInfo(config=None, context={1: 2}, data=None, field_name=None)'), + (None, {'context': None}, 'ValidationInfo(config=None, context=None, data=None, field_name=None)'), + ({'title': 'hello'}, {}, "ValidationInfo(config={'title': 'hello'}, context=None, data=None, field_name=None)"), ], ) def test_val_info_repr(config, kwargs, expected_repr): @@ -163,9 +157,7 @@ def f(input_value, info: core_schema.ValidationInfo): assert str(info) == expected_repr return input_value - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}}, config - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.str_schema()), config) assert v.validate_python('input value', **kwargs) == 'input value' @@ -174,9 +166,7 @@ def test_function_wrap(): def f(input_value, validator, info): return validator(input_value=input_value) + ' Changed' - v = SchemaValidator( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.str_schema())) assert v.validate_python('input value') == 'input value Changed' @@ -185,9 +175,7 @@ def test_function_wrap_no_info(): def f(input_value, validator): return validator(input_value=input_value) + ' Changed' - v = 
SchemaValidator( - {'type': 'function-wrap', 'function': {'type': 'no-info', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.no_info_wrap_validator_function(f, core_schema.str_schema())) assert v.validate_python('input value') == 'input value Changed' @@ -197,9 +185,7 @@ def f(input_value, validator, info): assert repr(validator) == str(validator) return plain_repr(validator) - v = SchemaValidator( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.str_schema())) assert ( v.validate_python('input value') @@ -211,9 +197,7 @@ def test_function_wrap_str(): def f(input_value, validator, info): return plain_repr(validator) - v = SchemaValidator( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.str_schema())) assert ( v.validate_python('input value') @@ -223,9 +207,7 @@ def f(input_value, validator, info): def test_function_wrap_not_callable(): with pytest.raises(SchemaError, match='function-wrap.function.typed-dict.function\n Input should be callable'): - validate_core_schema( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': []}, 'schema': {'type': 'str'}} - ) + validate_core_schema(core_schema.with_info_wrap_validator_function([], core_schema.str_schema())) with pytest.raises(SchemaError, match='function-wrap.function\n Field required'): validate_core_schema({'type': 'function-wrap', 'schema': {'type': 'str'}}) @@ -240,9 +222,7 @@ def f(input_value, validator, info): assert str(e).startswith('1 validation error for ValidatorCallable\n') raise e - v = SchemaValidator( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'int'}} - ) + v = 
SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.int_schema())) assert v.validate_python('42') == 84 with pytest.raises(ValidationError) as exc_info: @@ -274,9 +254,7 @@ def f(input_value, validator, info): assert str(e).startswith('1 validation error for ValidatorCallable\n') raise e - v = SchemaValidator( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'int'}}, config - ) + v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.int_schema()), config) with pytest.raises( ValidationError, @@ -289,9 +267,7 @@ def test_function_wrap_location(): def f(input_value, validator, info): return validator(input_value, outer_location='foo') + 2 - v = SchemaValidator( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'int'}} - ) + v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.int_schema())) assert v.validate_python(4) == 6 with pytest.raises(ValidationError) as exc_info: @@ -311,9 +287,7 @@ def test_function_wrap_invalid_location(): def f(input_value, validator, info): return validator(input_value, ('4',)) + 2 - v = SchemaValidator( - {'type': 'function-wrap', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'int'}} - ) + v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.int_schema())) with pytest.raises(TypeError, match='^outer_location must be a str or int$'): v.validate_python(4) @@ -323,9 +297,7 @@ def test_function_after(): def f(input_value, _info): return input_value + ' Changed' - v = SchemaValidator( - {'type': 'function-after', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_after_validator_function(f, core_schema.str_schema())) assert v.validate_python('input value') == 'input value Changed' @@ -334,9 +306,7 @@ def test_function_no_info(): def f(input_value): return 
input_value + ' Changed' - v = SchemaValidator( - {'type': 'function-after', 'function': {'type': 'no-info', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.no_info_after_validator_function(f, core_schema.str_schema())) assert v.validate_python('input value') == 'input value Changed' @@ -345,9 +315,7 @@ def test_function_after_raise(): def f(input_value, info): raise ValueError('foobar') - v = SchemaValidator( - {'type': 'function-after', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_after_validator_function(f, core_schema.str_schema())) with pytest.raises(ValidationError) as exc_info: assert v.validate_python('input value') == 'input value Changed' @@ -375,9 +343,7 @@ def test_function_after_error_hide_input(config, input_str): def f(input_value, info): raise ValueError('foobar') - v = SchemaValidator( - {'type': 'function-after', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}}, config - ) + v = SchemaValidator(core_schema.with_info_after_validator_function(f, core_schema.str_schema()), config) with pytest.raises(ValidationError, match=re.escape(f'Value error, foobar [{input_str}]')): v.validate_python('input value') @@ -399,7 +365,7 @@ def f(input_value, info): 'type': 'typed-dict-field', 'schema': { 'type': 'function-after', - 'function': {'type': 'field', 'function': f, 'field_name': 'test_field'}, + 'function': {'type': 'with-info', 'function': f, 'field_name': 'test_field'}, 'schema': {'type': 'str'}, }, } @@ -420,9 +386,7 @@ def f(input_value, info: core_schema.ValidationInfo): f_kwargs = deepcopy_info(info) return input_value + ' Changed' - v = SchemaValidator( - {'type': 'function-after', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_after_validator_function(f, core_schema.str_schema())) assert v.validate_python(b'abc') == 'abc Changed' assert 
f_kwargs == {'data': None, 'config': None, 'context': None, 'field_name': None} @@ -432,7 +396,7 @@ def test_function_plain(): def f(input_value, _info): return input_value * 2 - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'general', 'function': f}}) + v = SchemaValidator(core_schema.with_info_plain_validator_function(f)) assert v.validate_python(1) == 2 assert v.validate_python('x') == 'xx' @@ -442,7 +406,7 @@ def test_function_plain_no_info(): def f(input_value): return input_value * 2 - v = SchemaValidator({'type': 'function-plain', 'function': {'type': 'no-info', 'function': f}}) + v = SchemaValidator(core_schema.no_info_plain_validator_function(f)) assert v.validate_python(1) == 2 assert v.validate_python('x') == 'xx' @@ -453,7 +417,7 @@ def test_plain_with_schema(): validate_core_schema( { 'type': 'function-plain', - 'function': {'type': 'general', 'function': lambda x: x}, + 'function': {'type': 'with-info', 'function': lambda x: x}, 'schema': {'type': 'str'}, } ) @@ -497,9 +461,7 @@ def test_function_wrong_sig(): def f(input_value): return input_value + ' Changed' - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.str_schema())) # exception messages differ between python and pypy if platform.python_implementation() == 'PyPy': @@ -525,7 +487,7 @@ def __validate__(cls, input_value, info): v = SchemaValidator( { 'type': 'function-after', - 'function': {'type': 'general', 'function': Foobar.__validate__}, + 'function': {'type': 'with-info', 'function': Foobar.__validate__}, 'schema': {'type': 'str'}, } ) @@ -550,9 +512,7 @@ def test_raise_assertion_error(): def f(input_value, info): raise AssertionError('foobar') - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = 
SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.str_schema())) with pytest.raises(ValidationError) as exc_info: v.validate_python('input value') @@ -572,9 +532,7 @@ def test_raise_assertion_error_plain(): def f(input_value, info): raise AssertionError - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.str_schema())) with pytest.raises(ValidationError) as exc_info: v.validate_python('input value') @@ -599,9 +557,7 @@ def __str__(self): def f(input_value, info): raise MyError() - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.str_schema())) with pytest.raises(RuntimeError, match='internal error'): v.validate_python('input value') @@ -611,9 +567,7 @@ def test_raise_type_error(): def f(input_value, info): raise TypeError('foobar') - v = SchemaValidator( - {'type': 'function-before', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}} - ) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.str_schema())) with pytest.raises(TypeError, match='^foobar$'): v.validate_python('input value') @@ -623,11 +577,11 @@ def test_model_field_before_validator() -> None: class Model: x: str - def f(input_value: Any, info: core_schema.FieldValidationInfo) -> Any: + def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: assert info.field_name == 'x' assert info.data == {} - assert repr(info) == "FieldValidationInfo(config=None, context=None, data={}, field_name='x')" - assert str(info) == "FieldValidationInfo(config=None, context=None, data={}, field_name='x')" + assert repr(info) == "ValidationInfo(config=None, context=None, data={}, field_name='x')" + assert 
str(info) == "ValidationInfo(config=None, context=None, data={}, field_name='x')" assert isinstance(input_value, bytes) return f'input: {input_value.decode()}' @@ -637,7 +591,7 @@ def f(input_value: Any, info: core_schema.FieldValidationInfo) -> Any: core_schema.model_fields_schema( { 'x': core_schema.model_field( - core_schema.field_before_validator_function(f, 'x', core_schema.str_schema()) + core_schema.with_info_before_validator_function(f, core_schema.str_schema(), field_name='x') ) } ), @@ -651,7 +605,7 @@ def test_model_field_after_validator() -> None: class Model: x: str - def f(input_value: str, info: core_schema.FieldValidationInfo) -> Any: + def f(input_value: str, info: core_schema.ValidationInfo) -> Any: assert info.field_name == 'x' assert info.data == {} assert isinstance(input_value, str) @@ -663,7 +617,7 @@ def f(input_value: str, info: core_schema.FieldValidationInfo) -> Any: core_schema.model_fields_schema( { 'x': core_schema.model_field( - core_schema.field_after_validator_function(f, 'x', core_schema.str_schema()) + core_schema.with_info_after_validator_function(f, core_schema.str_schema(), field_name='x') ) } ), @@ -677,7 +631,7 @@ def test_model_field_plain_validator() -> None: class Model: x: str - def f(input_value: Any, info: core_schema.FieldValidationInfo) -> Any: + def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: assert info.field_name == 'x' assert info.data == {} assert isinstance(input_value, bytes) @@ -687,7 +641,7 @@ def f(input_value: Any, info: core_schema.FieldValidationInfo) -> Any: core_schema.model_schema( Model, core_schema.model_fields_schema( - {'x': core_schema.model_field(core_schema.field_plain_validator_function(f, 'x'))} + {'x': core_schema.model_field(core_schema.with_info_plain_validator_function(f, field_name='x'))} ), ) ) @@ -699,9 +653,7 @@ def test_model_field_wrap_validator() -> None: class Model: x: str - def f( - input_value: Any, val: core_schema.ValidatorFunctionWrapHandler, info: 
core_schema.FieldValidationInfo - ) -> Any: + def f(input_value: Any, val: core_schema.ValidatorFunctionWrapHandler, info: core_schema.ValidationInfo) -> Any: assert info.field_name == 'x' assert info.data == {} assert isinstance(input_value, bytes) @@ -713,7 +665,7 @@ def f( core_schema.model_fields_schema( { 'x': core_schema.model_field( - core_schema.field_wrap_validator_function(f, 'x', core_schema.str_schema()) + core_schema.with_info_wrap_validator_function(f, core_schema.str_schema(), field_name='x') ) } ), @@ -723,13 +675,9 @@ def f( assert v.validate_python({'x': b'foo'}).x == 'input: foo' -def check_that_info_has_no_model_data(info: core_schema.ValidationInfo) -> None: - with pytest.raises(AttributeError, match="No attribute named 'field_name'"): - info.field_name # type: ignore[attr-defined] - with pytest.raises(AttributeError, match="No attribute named 'data'"): - info.data # type: ignore[attr-defined] - assert not hasattr(info, 'field_name') - assert not hasattr(info, 'data') +def check_info_field_name_none(info: core_schema.ValidationInfo) -> None: + assert info.field_name is None + assert info.data == {} def test_non_model_field_before_validator_tries_to_access_field_info() -> None: @@ -737,7 +685,7 @@ class Model: x: str def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: - check_that_info_has_no_model_data(info) + check_info_field_name_none(info) assert isinstance(input_value, bytes) return f'input: {input_value.decode()}' @@ -747,7 +695,7 @@ def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: core_schema.model_fields_schema( { 'x': core_schema.model_field( - core_schema.general_before_validator_function(f, core_schema.str_schema()) + core_schema.with_info_before_validator_function(f, core_schema.str_schema()) ) } ), @@ -762,7 +710,7 @@ class Model: x: str def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: - check_that_info_has_no_model_data(info) + check_info_field_name_none(info) return f'input: 
{input_value}' v = SchemaValidator( @@ -771,7 +719,7 @@ def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: core_schema.model_fields_schema( { 'x': core_schema.model_field( - core_schema.general_after_validator_function(f, core_schema.str_schema()) + core_schema.with_info_after_validator_function(f, core_schema.str_schema()) ) } ), @@ -786,7 +734,7 @@ class Model: x: str def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: - check_that_info_has_no_model_data(info) + check_info_field_name_none(info) assert isinstance(input_value, bytes) return f'input: {input_value.decode()}' @@ -794,7 +742,7 @@ def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: core_schema.model_schema( Model, core_schema.model_fields_schema( - {'x': core_schema.model_field(core_schema.general_plain_validator_function(f))} + {'x': core_schema.model_field(core_schema.with_info_plain_validator_function(f))} ), ) ) @@ -808,14 +756,18 @@ class Model: x: str def f(input_value: Any, val: core_schema.ValidatorFunctionWrapHandler, info: core_schema.ValidationInfo) -> Any: - check_that_info_has_no_model_data(info) + check_info_field_name_none(info) return f'input: {val(input_value)}' v = SchemaValidator( core_schema.model_schema( Model, core_schema.model_fields_schema( - {'x': core_schema.model_field(core_schema.general_wrap_validator_function(f, core_schema.str_schema()))} + { + 'x': core_schema.model_field( + core_schema.with_info_wrap_validator_function(f, core_schema.str_schema()) + ) + } ), ) ) @@ -826,7 +778,7 @@ def f(input_value: Any, val: core_schema.ValidatorFunctionWrapHandler, info: cor def test_typed_dict_data() -> None: info_stuff = None - def f(input_value: Any, info: core_schema.FieldValidationInfo) -> Any: + def f(input_value: Any, info: core_schema.ValidationInfo) -> Any: nonlocal info_stuff info_stuff = {'field_name': info.field_name, 'data': info.data.copy()} assert isinstance(input_value, str) @@ -838,7 +790,7 @@ def f(input_value: Any, info: 
core_schema.FieldValidationInfo) -> Any: 'a': core_schema.typed_dict_field(core_schema.int_schema()), 'b': core_schema.typed_dict_field(core_schema.int_schema()), 'c': core_schema.typed_dict_field( - core_schema.field_after_validator_function(f, 'c', core_schema.str_schema()) + core_schema.with_info_after_validator_function(f, core_schema.str_schema(), field_name='c') ), } ) @@ -894,7 +846,7 @@ def f(input_value: Any, *args: Any) -> Any: Model, { 'type': f'function-{mode}', - 'function': {'type': 'general', 'function': f}, + 'function': {'type': 'with-info', 'function': f}, 'schema': core_schema.model_fields_schema( { 'x': core_schema.model_field(core_schema.str_schema()), @@ -922,7 +874,7 @@ def f(v: Any, info: core_schema.ValidationInfo) -> Any: calls.append(info.mode) return v - v = SchemaValidator(core_schema.general_before_validator_function(f, core_schema.int_schema())) + v = SchemaValidator(core_schema.with_info_before_validator_function(f, core_schema.int_schema())) assert v.validate_python(1) == 1 assert calls == ['python'] calls.clear() @@ -930,7 +882,7 @@ def f(v: Any, info: core_schema.ValidationInfo) -> Any: assert calls == ['json'] calls.clear() - v = SchemaValidator(core_schema.general_after_validator_function(f, core_schema.int_schema())) + v = SchemaValidator(core_schema.with_info_after_validator_function(f, core_schema.int_schema())) assert v.validate_python(1) == 1 assert calls == ['python'] calls.clear() @@ -942,7 +894,7 @@ def f_w(v: Any, handler: core_schema.ValidatorFunctionWrapHandler, info: core_sc calls.append(info.mode) return handler(v) - v = SchemaValidator(core_schema.general_wrap_validator_function(f_w, core_schema.int_schema())) + v = SchemaValidator(core_schema.with_info_wrap_validator_function(f_w, core_schema.int_schema())) assert v.validate_python(1) == 1 assert calls == ['python'] calls.clear() @@ -961,8 +913,8 @@ def sample_repr(v: Any, info: core_schema.ValidationInfo) -> Any: v = SchemaValidator( core_schema.chain_schema( [ - 
core_schema.general_plain_validator_function(sample_repr), - core_schema.field_plain_validator_function(sample_repr, field_name='x'), + core_schema.with_info_plain_validator_function(sample_repr), + core_schema.with_info_plain_validator_function(sample_repr, field_name='x'), ] ) ) @@ -975,8 +927,8 @@ def __repr__(self) -> str: # insert_assert(reprs) assert reprs == [ - 'ValidationInfo(config=None, context=None)', - "FieldValidationInfo(config=None, context=None, field_name='x')", + 'ValidationInfo(config=None, context=None, data=None, field_name=None)', + "ValidationInfo(config=None, context=None, data=None, field_name='x')", ] diff --git a/tests/validators/test_list.py b/tests/validators/test_list.py index d3cfaa528..05e5ed832 100644 --- a/tests/validators/test_list.py +++ b/tests/validators/test_list.py @@ -236,9 +236,7 @@ def test_list_function(): def f(input_value, info): return input_value * 2 - v = SchemaValidator( - {'type': 'list', 'items_schema': {'type': 'function-plain', 'function': {'type': 'general', 'function': f}}} - ) + v = SchemaValidator({'type': 'list', 'items_schema': core_schema.with_info_plain_validator_function(f)}) assert v.validate_python([1, 2, 3]) == [2, 4, 6] @@ -247,9 +245,7 @@ def test_list_function_val_error(): def f(input_value, info): raise ValueError(f'error {input_value}') - v = SchemaValidator( - {'type': 'list', 'items_schema': {'type': 'function-plain', 'function': {'type': 'general', 'function': f}}} - ) + v = SchemaValidator({'type': 'list', 'items_schema': core_schema.with_info_plain_validator_function(f)}) with pytest.raises(ValidationError) as exc_info: v.validate_python([1, 2]) @@ -275,9 +271,7 @@ def test_list_function_internal_error(): def f(input_value, info): raise RuntimeError(f'error {input_value}') - v = SchemaValidator( - {'type': 'list', 'items_schema': {'type': 'function-plain', 'function': {'type': 'general', 'function': f}}} - ) + v = SchemaValidator({'type': 'list', 'items_schema': 
core_schema.with_info_plain_validator_function(f)}) with pytest.raises(RuntimeError, match='^error 1$') as exc_info: v.validate_python([1, 2]) diff --git a/tests/validators/test_model.py b/tests/validators/test_model.py index 7d5c9ff89..69b6a06ec 100644 --- a/tests/validators/test_model.py +++ b/tests/validators/test_model.py @@ -121,7 +121,7 @@ def f( schema = core_schema.model_schema( MyModel, - core_schema.general_wrap_validator_function( + core_schema.with_info_wrap_validator_function( f, core_schema.model_fields_schema({'field_a': core_schema.model_field(core_schema.int_schema())}) ), ) @@ -155,7 +155,7 @@ def f(input_value: Dict[str, Any], info: core_schema.ValidationInfo): schema = core_schema.model_schema( MyModel, - core_schema.general_before_validator_function( + core_schema.with_info_before_validator_function( f, core_schema.model_fields_schema({'field_a': core_schema.model_field(core_schema.int_schema())}) ), ) @@ -227,7 +227,7 @@ def f(input_value, info): 'cls': MyModel, 'schema': { 'type': f'function-{mode}', - 'function': {'type': 'general', 'function': f}, + 'function': {'type': 'with-info', 'function': f}, 'schema': { 'type': 'model-fields', 'fields': {'field_a': {'type': 'model-field', 'schema': {'type': 'str'}}}, @@ -244,13 +244,7 @@ class MyModel: def f(input_value): return input_value, {1: 2}, {'field_a'} - v = SchemaValidator( - { - 'type': 'model', - 'cls': MyModel, - 'schema': {'type': 'function-plain', 'function': {'type': 'no-info', 'function': f}}, - } - ) + v = SchemaValidator({'type': 'model', 'cls': MyModel, 'schema': core_schema.no_info_plain_validator_function(f)}) m = v.validate_python({'field_a': 'test'}) assert isinstance(m, MyModel) assert m.__dict__ == {'field_a': 'test'} @@ -363,7 +357,7 @@ def f(input_value, info): 'cls': MyModel, 'schema': { 'type': 'function-after', - 'function': {'type': 'general', 'function': f}, + 'function': {'type': 'with-info', 'function': f}, 'schema': { 'type': 'model-fields', 'fields': {'field_a': 
{'type': 'model-field', 'schema': {'type': 'str'}}}, @@ -996,7 +990,9 @@ def func(x, info): { 'field_a': core_schema.model_field(core_schema.str_schema()), 'field_b': core_schema.model_field( - core_schema.field_after_validator_function(func, 'field_b', core_schema.int_schema()) + core_schema.with_info_after_validator_function( + func, core_schema.int_schema(), field_name='field_b' + ) ), 'field_c': core_schema.model_field(core_schema.int_schema()), } @@ -1009,14 +1005,14 @@ def func(x, info): assert m.field_b == 246 assert m.field_c == 456 assert m.__pydantic_fields_set__ == {'field_a', 'field_b', 'field_c'} - assert calls == ["FieldValidationInfo(config=None, context=None, data={'field_a': 'x'}, field_name='field_b')"] + assert calls == ["ValidationInfo(config=None, context=None, data={'field_a': 'x'}, field_name='field_b')"] v.validate_assignment(m, 'field_b', '111') assert m.field_b == 222 assert calls == [ - "FieldValidationInfo(config=None, context=None, data={'field_a': 'x'}, field_name='field_b')", - "FieldValidationInfo(config=None, context=None, data={'field_a': 'x', 'field_c': 456}, field_name='field_b')", + "ValidationInfo(config=None, context=None, data={'field_a': 'x'}, field_name='field_b')", + "ValidationInfo(config=None, context=None, data={'field_a': 'x', 'field_c': 456}, field_name='field_b')", ] @@ -1085,19 +1081,19 @@ class MyModel: 'function_schema,call1, call2', [ ( - core_schema.general_after_validator_function, - (({'a': 1, 'b': 2}, None, {'b'}), 'ValidationInfo(config=None, context=None)'), - (({'a': 10, 'b': 2}, None, {'a'}), 'ValidationInfo(config=None, context=None)'), + core_schema.with_info_after_validator_function, + (({'a': 1, 'b': 2}, None, {'b'}), 'ValidationInfo(config=None, context=None, data=None, field_name=None)'), + (({'a': 10, 'b': 2}, None, {'a'}), 'ValidationInfo(config=None, context=None, data=None, field_name=None)'), ), ( - core_schema.general_before_validator_function, - ({'b': 2}, 'ValidationInfo(config=None, 
context=None)'), - ({'a': 10, 'b': 2}, 'ValidationInfo(config=None, context=None)'), + core_schema.with_info_before_validator_function, + ({'b': 2}, 'ValidationInfo(config=None, context=None, data=None, field_name=None)'), + ({'a': 10, 'b': 2}, 'ValidationInfo(config=None, context=None, data=None, field_name=None)'), ), ( - core_schema.general_wrap_validator_function, - ({'b': 2}, 'ValidationInfo(config=None, context=None)'), - ({'a': 10, 'b': 2}, 'ValidationInfo(config=None, context=None)'), + core_schema.with_info_wrap_validator_function, + ({'b': 2}, 'ValidationInfo(config=None, context=None, data=None, field_name=None)'), + ({'a': 10, 'b': 2}, 'ValidationInfo(config=None, context=None, data=None, field_name=None)'), ), ], ) diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index 05201759a..e21f50008 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -280,7 +280,7 @@ def func_b(input_value, info): 'type': 'model-field', 'schema': { 'type': 'function-after', - 'function': {'type': 'general', 'function': func_a}, + 'function': {'type': 'with-info', 'function': func_a}, 'schema': {'type': 'str'}, }, }, @@ -288,7 +288,7 @@ def func_b(input_value, info): 'type': 'model-field', 'schema': { 'type': 'function-after', - 'function': {'type': 'general', 'function': func_b}, + 'function': {'type': 'with-info', 'function': func_b}, 'schema': {'type': 'int'}, }, }, @@ -1608,7 +1608,7 @@ def wrap_function(input_value, validator, info): 'on_error': 'raise', 'schema': { 'type': 'function-wrap', - 'function': {'type': 'general', 'function': wrap_function}, + 'function': {'type': 'with-info', 'function': wrap_function}, 'schema': {'type': 'str'}, }, }, diff --git a/tests/validators/test_model_init.py b/tests/validators/test_model_init.py index fb8ef9f43..5521f8da4 100644 --- a/tests/validators/test_model_init.py +++ b/tests/validators/test_model_init.py @@ -100,7 +100,7 @@ def f(input_value, 
_info): v = SchemaValidator( { 'type': 'function-before', - 'function': {'type': 'general', 'function': f}, + 'function': {'type': 'with-info', 'function': f}, 'schema': { 'type': 'model', 'cls': MyModel, @@ -136,7 +136,7 @@ def f(input_value, _info): v = SchemaValidator( { 'type': 'function-after', - 'function': {'type': 'general', 'function': f}, + 'function': {'type': 'with-info', 'function': f}, 'schema': { 'type': 'model', 'cls': MyModel, @@ -174,7 +174,7 @@ def f(input_value, handler, _info): v = SchemaValidator( { 'type': 'function-wrap', - 'function': {'type': 'general', 'function': f}, + 'function': {'type': 'with-info', 'function': f}, 'schema': { 'type': 'model', 'cls': MyModel, @@ -442,18 +442,24 @@ def _wrap_validator(cls, v, validator, info): field_schema = core_schema.int_schema() if validator == 'field': - field_schema = core_schema.field_before_validator_function(Model._validator, 'a', field_schema) - field_schema = core_schema.field_wrap_validator_function(Model._wrap_validator, 'a', field_schema) - field_schema = core_schema.field_after_validator_function(Model._validator, 'a', field_schema) + field_schema = core_schema.with_info_before_validator_function( + Model._validator, field_schema, field_name='a' + ) + field_schema = core_schema.with_info_wrap_validator_function( + Model._wrap_validator, field_schema, field_name='a' + ) + field_schema = core_schema.with_info_after_validator_function( + Model._validator, field_schema, field_name='a' + ) model_schema = core_schema.model_schema( Model, core_schema.model_fields_schema({'a': core_schema.model_field(field_schema)}) ) if validator == 'model': - model_schema = core_schema.general_before_validator_function(Model._validator, model_schema) - model_schema = core_schema.general_wrap_validator_function(Model._wrap_validator, model_schema) - model_schema = core_schema.general_after_validator_function(Model._validator, model_schema) + model_schema = 
core_schema.with_info_before_validator_function(Model._validator, model_schema) + model_schema = core_schema.with_info_wrap_validator_function(Model._wrap_validator, model_schema) + model_schema = core_schema.with_info_after_validator_function(Model._validator, model_schema) # If any of the Rust validators don't implement traversal properly, # there will be an undetectable cycle created by this assignment diff --git a/tests/validators/test_model_root.py b/tests/validators/test_model_root.py index db9d91eb5..8daa326ea 100644 --- a/tests/validators/test_model_root.py +++ b/tests/validators/test_model_root.py @@ -139,20 +139,22 @@ def f(input_value: str, info): v = SchemaValidator( core_schema.model_schema( - RootModel, core_schema.field_after_validator_function(f, 'root', core_schema.str_schema()), root_model=True + RootModel, + core_schema.with_info_after_validator_function(f, core_schema.str_schema(), field_name='root'), + root_model=True, ) ) m = v.validate_python('foobar', context='call 1') assert isinstance(m, RootModel) assert m.root == 'foobar validated' - assert call_infos == ["FieldValidationInfo(config=None, context='call 1', field_name='root')"] + assert call_infos == ["ValidationInfo(config=None, context='call 1', data=None, field_name='root')"] m2 = v.validate_assignment(m, 'root', 'baz', context='assignment call') assert m2 is m assert m.root == 'baz validated' assert call_infos == [ - "FieldValidationInfo(config=None, context='call 1', field_name='root')", - "FieldValidationInfo(config=None, context='assignment call', field_name='root')", + "ValidationInfo(config=None, context='call 1', data=None, field_name='root')", + "ValidationInfo(config=None, context='assignment call', data=None, field_name='root')", ] diff --git a/tests/validators/test_nullable.py b/tests/validators/test_nullable.py index fc20aa414..a74d56138 100644 --- a/tests/validators/test_nullable.py +++ b/tests/validators/test_nullable.py @@ -47,7 +47,7 @@ def fn(): def validate(v, info): 
return v - schema = core_schema.general_plain_validator_function(validate) + schema = core_schema.with_info_plain_validator_function(validate) schema = core_schema.nullable_schema(schema) # If any of the Rust validators don't implement traversal properly, diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 67cf97d97..5f0729d25 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -1066,7 +1066,7 @@ def wrap_function(input_value, validator, info): 'on_error': 'raise', 'schema': { 'type': 'function-wrap', - 'function': {'type': 'general', 'function': wrap_function}, + 'function': {'type': 'with-info', 'function': wrap_function}, 'schema': {'type': 'str'}, }, }, @@ -1173,7 +1173,7 @@ def fn(): def validate(v, info): return v - schema = core_schema.general_plain_validator_function(validate) + schema = core_schema.with_info_plain_validator_function(validate) schema = core_schema.typed_dict_schema( {'f': core_schema.typed_dict_field(schema)}, extra_behavior='allow', extras_schema=schema ) diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index e0ebd2fb3..808e4807d 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -633,7 +633,7 @@ class Defaulted(int): def _validator(cls, v, info): return Defaulted(v) - schema = core_schema.general_plain_validator_function(Defaulted._validator) + schema = core_schema.with_info_plain_validator_function(Defaulted._validator) schema = core_schema.with_default_schema(schema, default=Defaulted(0)) # If any of the Rust validators don't implement traversal properly, From 0ebd60703814b3f5a03eb483e6ec41df15650dec Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 13:48:43 +0100 Subject: [PATCH 051/550] Bump smallvec from 1.11.0 to 1.11.1 (#988) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] 
<49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d70598ad8..cfcbf6591 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -434,9 +434,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" [[package]] name = "speedate" diff --git a/Cargo.toml b/Cargo.toml index 154d310de..f28ec2944 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_ enum_dispatch = "0.3.8" serde = { version = "1.0.188", features = ["derive"] } speedate = "0.12.0" -smallvec = "1.11.0" +smallvec = "1.11.1" ahash = "0.8.0" url = "2.4.1" # idna is already required by url, added here to be explicit From e610984bd12a673f62892f6d2ab42d344e2714ec Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 25 Sep 2023 13:48:53 +0100 Subject: [PATCH 052/550] Bump ruff from 0.0.290 to 0.0.291 (#987) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index ff8f85657..48e914dbb 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 griffe==0.36.2 pyright==1.1.327 -ruff==0.0.290 +ruff==0.0.291 mypy==1.5.1 From 8435153f2c125b0ea4dbfc9c352bfcd02ae75181 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:59:40 +0100 Subject: [PATCH 053/550] improve 
quality of too short / too long error messages (#990) --- src/errors/types.rs | 7 +-- src/input/return_enums.rs | 72 +++++++++++------------------- src/validators/generator.rs | 2 +- src/validators/list.rs | 2 +- src/validators/tuple.rs | 20 +++++---- tests/validators/test_frozenset.py | 6 +-- tests/validators/test_generator.py | 8 ++-- tests/validators/test_list.py | 3 +- tests/validators/test_set.py | 6 +-- tests/validators/test_tuple.py | 6 +-- 10 files changed, 56 insertions(+), 76 deletions(-) diff --git a/src/errors/types.rs b/src/errors/types.rs index da4d5fdd7..d537158ba 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -227,7 +227,7 @@ error_types! { TooLong { field_type: {ctx_type: String, ctx_fn: field_from_context}, max_length: {ctx_type: usize, ctx_fn: field_from_context}, - actual_length: {ctx_type: usize, ctx_fn: field_from_context}, + actual_length: {ctx_type: Option, ctx_fn: field_from_context}, }, // --------------------- // generic collection and iteration errors @@ -630,7 +630,7 @@ impl ErrorType { .. } => { let expected_plural = plural_s(*min_length); - to_string_render!(tmpl, field_type, min_length, actual_length, expected_plural) + to_string_render!(tmpl, field_type, min_length, actual_length, expected_plural,) } Self::TooLong { field_type, @@ -639,7 +639,8 @@ impl ErrorType { .. } => { let expected_plural = plural_s(*max_length); - to_string_render!(tmpl, field_type, max_length, actual_length, expected_plural) + let actual_length = actual_length.map_or(Cow::Borrowed("more"), |v| Cow::Owned(v.to_string())); + to_string_render!(tmpl, field_type, max_length, actual_length, expected_plural,) } Self::IterationError { error, .. } => render!(tmpl, error), Self::StringTooShort { min_length, .. 
} => to_string_render!(tmpl, min_length), diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index d7d5a7f3b..c492f40f0 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -106,54 +106,33 @@ struct MaxLengthCheck<'a, INPUT> { max_length: Option, field_type: &'a str, input: &'a INPUT, - known_input_length: usize, + actual_length: Option, } impl<'a, INPUT: Input<'a>> MaxLengthCheck<'a, INPUT> { - fn new(max_length: Option, field_type: &'a str, input: &'a INPUT, known_input_length: usize) -> Self { + fn new(max_length: Option, field_type: &'a str, input: &'a INPUT, actual_length: Option) -> Self { Self { current_length: 0, max_length, field_type, input, - known_input_length, + actual_length, } } fn incr(&mut self) -> ValResult<'a, ()> { - match self.max_length { - Some(max_length) => { - self.current_length += 1; - if self.current_length > max_length { - let biggest_length = if self.known_input_length > self.current_length { - self.known_input_length - } else { - self.current_length - }; - return Err(ValError::new( - ErrorType::TooLong { - field_type: self.field_type.to_string(), - max_length, - actual_length: biggest_length, - context: None, - }, - self.input, - )); - } - } - None => { - self.current_length += 1; - if self.current_length > self.known_input_length { - return Err(ValError::new( - ErrorType::TooLong { - field_type: self.field_type.to_string(), - max_length: self.known_input_length, - actual_length: self.current_length, - context: None, - }, - self.input, - )); - } + if let Some(max_length) = self.max_length { + self.current_length += 1; + if self.current_length > max_length { + return Err(ValError::new( + ErrorType::TooLong { + field_type: self.field_type.to_string(), + max_length, + actual_length: self.actual_length, + context: None, + }, + self.input, + )); } } Ok(()) @@ -255,13 +234,15 @@ fn validate_iter_to_set<'a, 's>( Ok(item) => { set.build_add(item)?; if let Some(max_length) = max_length { - let 
actual_length = set.build_len(); - if actual_length > max_length { + if set.build_len() > max_length { return Err(ValError::new( ErrorType::TooLong { field_type: field_type.to_string(), max_length, - actual_length, + // The logic here is that it doesn't matter how many elements the + // input actually had; all we know is it had more than the allowed + // number of deduplicated elements. + actual_length: None, context: None, }, input, @@ -335,10 +316,9 @@ impl<'a> GenericIterable<'a> { validator: &'s CombinedValidator, state: &mut ValidationState, ) -> ValResult<'a, Vec> { - let capacity = self - .generic_len() - .unwrap_or_else(|| max_length.unwrap_or(DEFAULT_CAPACITY)); - let max_length_check = MaxLengthCheck::new(max_length, field_type, input, capacity); + let actual_length = self.generic_len(); + let capacity = actual_length.unwrap_or(DEFAULT_CAPACITY); + let max_length_check = MaxLengthCheck::new(max_length, field_type, input, actual_length); macro_rules! validate { ($iter:expr) => { @@ -394,10 +374,8 @@ impl<'a> GenericIterable<'a> { field_type: &'static str, max_length: Option, ) -> ValResult<'a, Vec> { - let capacity = self - .generic_len() - .unwrap_or_else(|| max_length.unwrap_or(DEFAULT_CAPACITY)); - let max_length_check = MaxLengthCheck::new(max_length, field_type, input, capacity); + let actual_length = self.generic_len(); + let max_length_check = MaxLengthCheck::new(max_length, field_type, input, actual_length); match self { GenericIterable::List(collection) => { diff --git a/src/validators/generator.rs b/src/validators/generator.rs index bf6d009e1..0cff7e28e 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -145,7 +145,7 @@ impl ValidatorIterator { ErrorType::TooLong { field_type: "Generator".to_string(), max_length, - actual_length: index + 1, + actual_length: None, context: None, }, $iter.input_as_error_value(py), diff --git a/src/validators/list.rs b/src/validators/list.rs index 49eaede67..ffd7a118e 100644 --- 
a/src/validators/list.rs +++ b/src/validators/list.rs @@ -58,7 +58,7 @@ macro_rules! length_check { crate::errors::ErrorType::TooLong { field_type: $field_type.to_string(), max_length, - actual_length, + actual_length: Some(actual_length), context: None, }, $input, diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 07887fddb..5c2c09bec 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -137,8 +137,7 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, extras_validator: &Option>, items_validators: &[CombinedValidator], collection_iter: &mut T, - collection_len: Option, - expected_length: usize, + actual_length: Option, ) -> ValResult<'data, ()> { for (index, validator) in items_validators.iter().enumerate() { match collection_iter.next() { @@ -167,7 +166,7 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, errors.extend( line_errors .into_iter() - .map(|err| err.with_outer_location((index + expected_length).into())), + .map(|err| err.with_outer_location((index + items_validators.len()).into())), ); } Err(ValError::Omit) => (), @@ -177,8 +176,8 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, errors.push(ValLineError::new( ErrorType::TooLong { field_type: "Tuple".to_string(), - max_length: expected_length, - actual_length: collection_len.unwrap_or(index), + max_length: items_validators.len(), + actual_length, context: None, }, input, @@ -204,8 +203,12 @@ impl Validator for TuplePositionalValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let collection = input.validate_tuple(state.strict_or(self.strict))?; - let expected_length = self.items_validators.len(); - let collection_len = collection.generic_len(); + let actual_length = collection.generic_len(); + let expected_length = if self.extras_validator.is_some() { + actual_length.unwrap_or(self.items_validators.len()) + } else { + self.items_validators.len() + }; let mut output: Vec = Vec::with_capacity(expected_length); let mut errors: Vec 
= Vec::new(); @@ -221,8 +224,7 @@ impl Validator for TuplePositionalValidator { &self.extras_validator, &self.items_validators, &mut $collection_iter, - collection_len, - expected_length, + actual_length, )? }}; } diff --git a/tests/validators/test_frozenset.py b/tests/validators/test_frozenset.py index 2e53b38cb..9ec2f10d5 100644 --- a/tests/validators/test_frozenset.py +++ b/tests/validators/test_frozenset.py @@ -148,12 +148,12 @@ def generate_repeats(): ( {'max_length': 3}, {1, 2, 3, 4}, - Err('Frozenset should have at most 3 items after validation, not 4 [type=too_long,'), + Err('Frozenset should have at most 3 items after validation, not more [type=too_long,'), ), ( {'items_schema': {'type': 'int'}, 'max_length': 3}, {1, 2, 3, 4}, - Err('Frozenset should have at most 3 items after validation, not 4 [type=too_long,'), + Err('Frozenset should have at most 3 items after validation, not more [type=too_long,'), ), # length check after set creation ({'max_length': 3}, [1, 1, 2, 2, 3, 3], {1, 2, 3}), @@ -161,7 +161,7 @@ def generate_repeats(): ( {'max_length': 3}, infinite_generator(), - Err('Frozenset should have at most 3 items after validation, not 4 [type=too_long,'), + Err('Frozenset should have at most 3 items after validation, not more [type=too_long,'), ), ], ) diff --git a/tests/validators/test_generator.py b/tests/validators/test_generator.py index 0bb637443..f8d7fd429 100644 --- a/tests/validators/test_generator.py +++ b/tests/validators/test_generator.py @@ -118,9 +118,9 @@ def test_too_long(py_and_json: PyAndJson): { 'type': 'too_long', 'loc': (), - 'msg': 'Generator should have at most 2 items after validation, not 3', + 'msg': 'Generator should have at most 2 items after validation, not more', 'input': [1, 2, 3], - 'ctx': {'field_type': 'Generator', 'max_length': 2, 'actual_length': 3}, + 'ctx': {'field_type': 'Generator', 'max_length': 2, 'actual_length': None}, } ] @@ -167,8 +167,8 @@ def test_generator_too_long(): 'type': 'too_long', 'loc': (), 
'input': HasRepr(IsStr(regex='')), - 'msg': 'Generator should have at most 2 items after validation, not 3', - 'ctx': {'field_type': 'Generator', 'max_length': 2, 'actual_length': 3}, + 'msg': 'Generator should have at most 2 items after validation, not more', + 'ctx': {'field_type': 'Generator', 'max_length': 2, 'actual_length': None}, } ] diff --git a/tests/validators/test_list.py b/tests/validators/test_list.py index 05e5ed832..6f5ca8a34 100644 --- a/tests/validators/test_list.py +++ b/tests/validators/test_list.py @@ -160,14 +160,13 @@ def test_list_error(input_value, index): ( {'max_length': 44}, infinite_generator(), - Err('List should have at most 44 items after validation, not 45 [type=too_long,'), + Err('List should have at most 44 items after validation, not more [type=too_long,'), ), ( {'max_length': 4, 'items_schema': {'type': 'int'}}, [0, 1, 2, 3, 4, 5, 6, 7, 8], Err('List should have at most 4 items after validation, not 9 [type=too_long,'), ), - ({}, infinite_generator(), Err('List should have at most 10 items after validation, not 11 [type=too_long,')), ], ) def test_list_length_constraints(kwargs: Dict[str, Any], input_value, expected): diff --git a/tests/validators/test_set.py b/tests/validators/test_set.py index 0b7bbe373..a6babb80d 100644 --- a/tests/validators/test_set.py +++ b/tests/validators/test_set.py @@ -126,12 +126,12 @@ def generate_repeats(): ( {'max_length': 3}, {1, 2, 3, 4}, - Err('Set should have at most 3 items after validation, not 4 [type=too_long,'), + Err('Set should have at most 3 items after validation, not more [type=too_long,'), ), ( {'max_length': 3}, [1, 2, 3, 4], - Err('Set should have at most 3 items after validation, not 4 [type=too_long,'), + Err('Set should have at most 3 items after validation, not more [type=too_long,'), ), ({'max_length': 3, 'items_schema': {'type': 'int'}}, {1, 2, 3, 4}, Err('type=too_long,')), ({'max_length': 3, 'items_schema': {'type': 'int'}}, [1, 2, 3, 4], Err('type=too_long,')), @@ -141,7 
+141,7 @@ def generate_repeats(): ( {'max_length': 3}, infinite_generator(), - Err('Set should have at most 3 items after validation, not 4 [type=too_long,'), + Err('Set should have at most 3 items after validation, not more [type=too_long,'), ), ], ids=repr, diff --git a/tests/validators/test_tuple.py b/tests/validators/test_tuple.py index b7f3e9720..023a2b48e 100644 --- a/tests/validators/test_tuple.py +++ b/tests/validators/test_tuple.py @@ -107,13 +107,13 @@ def test_tuple_strict_fails_without_tuple(wrong_coll_type: Type[Any], mode, item ), ( {'max_length': 3}, - [1, 2, 3, 4], - Err('Tuple should have at most 3 items after validation, not 4 [type=too_long,'), + [1, 2, 3, 4, 5], + Err('Tuple should have at most 3 items after validation, not 5 [type=too_long,'), ), ( {'max_length': 3}, infinite_generator(), - Err('Tuple should have at most 3 items after validation, not 4 [type=too_long,'), + Err('Tuple should have at most 3 items after validation, not more [type=too_long,'), ), ], ids=repr, From 1a966d55581e1a1379cfe6274da6323c9786aefb Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:24:48 +0100 Subject: [PATCH 054/550] bump version to 2.10.1 (#991) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cfcbf6591..a2db9901f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,7 +242,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.10.0" +version = "2.10.1" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index f28ec2944..7101a7019 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.10.0" +version = "2.10.1" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From a8fb1e3f46598498b2f01d2a5949ae501739717f Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com>
Date: Thu, 28 Sep 2023 09:45:34 +0100 Subject: [PATCH 055/550] Replace definitions `Vec` with `OnceLock` slots (#992) --- src/definitions.rs | 268 +++++++++++++----- src/py_gc.rs | 8 + src/serializers/extra.rs | 9 - src/serializers/mod.rs | 9 +- src/serializers/shared.rs | 4 +- src/serializers/type_serializers/dataclass.rs | 4 +- .../type_serializers/definitions.rs | 30 +- src/serializers/type_serializers/model.rs | 4 +- src/serializers/type_serializers/nullable.rs | 6 +- src/serializers/type_serializers/union.rs | 14 +- .../type_serializers/with_default.rs | 6 +- src/validators/any.rs | 8 +- src/validators/arguments.rs | 26 +- src/validators/bool.rs | 8 +- src/validators/bytes.rs | 16 +- src/validators/call.rs | 21 +- src/validators/callable.rs | 8 +- src/validators/chain.rs | 16 +- src/validators/custom_error.rs | 14 +- src/validators/dataclass.rs | 32 +-- src/validators/date.rs | 8 +- src/validators/datetime.rs | 8 +- src/validators/decimal.rs | 8 +- src/validators/definitions.rs | 100 +++---- src/validators/dict.rs | 17 +- src/validators/float.rs | 16 +- src/validators/frozenset.rs | 14 +- src/validators/function.rs | 53 ++-- src/validators/generator.rs | 39 ++- src/validators/int.rs | 16 +- src/validators/is_instance.rs | 8 +- src/validators/is_subclass.rs | 8 +- src/validators/json.rs | 16 +- src/validators/json_or_python.rs | 17 +- src/validators/lax_or_strict.rs | 16 +- src/validators/list.rs | 49 ++-- src/validators/literal.rs | 14 +- src/validators/mod.rs | 56 ++-- src/validators/model.rs | 14 +- src/validators/model_fields.rs | 22 +- src/validators/none.rs | 8 +- src/validators/nullable.rs | 14 +- src/validators/set.rs | 14 +- src/validators/string.rs | 18 +- src/validators/time.rs | 8 +- src/validators/timedelta.rs | 8 +- src/validators/tuple.rs | 44 +-- src/validators/typed_dict.rs | 22 +- src/validators/union.rs | 89 +++--- src/validators/url.rs | 16 +- src/validators/uuid.rs | 8 +- src/validators/validation_state.rs | 18 +- 
src/validators/with_default.rs | 14 +- .../validators/test_definitions_recursive.py | 192 ++++++++++++- 54 files changed, 793 insertions(+), 690 deletions(-) diff --git a/src/definitions.rs b/src/definitions.rs index 0d01fd2ae..1eb813015 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -3,16 +3,20 @@ /// Unlike json schema we let you put definitions inline, not just in a single '#/$defs/' block or similar. /// We use DefinitionsBuilder to collect the references / definitions into a single vector /// and then get a definition from a reference using an integer id (just for performance of not using a HashMap) -use std::collections::hash_map::Entry; +use std::{ + collections::hash_map::Entry, + fmt::Debug, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, OnceLock, + }, +}; -use pyo3::prelude::*; +use pyo3::{prelude::*, PyTraverseError, PyVisit}; use ahash::AHashMap; -use crate::build_tools::py_schema_err; - -// An integer id for the reference -pub type ReferenceId = usize; +use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse}; /// Definitions are validators and serializers that are /// shared by reference. @@ -24,91 +28,227 @@ pub type ReferenceId = usize; /// They get indexed by a ReferenceId, which are integer identifiers /// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer} /// gets build. -pub type Definitions = [T]; +#[derive(Clone)] +pub struct Definitions(AHashMap, Definition>); -#[derive(Clone, Debug)] -struct Definition { - pub id: ReferenceId, - pub value: Option, +impl Definitions { + pub fn values(&self) -> impl Iterator> { + self.0.values() + } +} + +/// Internal type which contains a definition to be filled +pub struct Definition(Arc>); + +impl Definition { + pub fn get(&self) -> Option<&T> { + self.0.value.get() + } +} + +struct DefinitionInner { + value: OnceLock, + name: LazyName, +} + +/// Reference to a definition. 
+pub struct DefinitionRef { + name: Arc, + value: Definition, +} + +// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone) +impl Clone for DefinitionRef { + fn clone(&self) -> Self { + Self { + name: self.name.clone(), + value: self.value.clone(), + } + } +} + +impl DefinitionRef { + pub fn id(&self) -> usize { + Arc::as_ptr(&self.value.0) as usize + } + + pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str { + match self.value.0.value.get() { + Some(value) => self.value.0.name.get_or_init(|| init(value)), + None => "...", + } + } + + pub fn get(&self) -> Option<&T> { + self.value.0.value.get() + } +} + +impl Debug for DefinitionRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // To avoid possible infinite recursion from recursive definitions, + // a DefinitionRef just displays debug as its name + self.name.fmt(f) + } +} + +impl Debug for Definitions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Formatted as a list for backwards compatibility; in principle + // this could be formatted as a map. Maybe change in a future + // minor release of pydantic. 
+ write![f, "["]?; + let mut first = true; + for def in self.0.values() { + write![f, "{sep}{def:?}", sep = if first { "" } else { ", " }]?; + first = false; + } + write![f, "]"]?; + Ok(()) + } +} + +impl Clone for Definition { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Debug for Definition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0.value.get() { + Some(value) => value.fmt(f), + None => "...".fmt(f), + } + } +} + +impl PyGcTraverse for DefinitionRef { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + if let Some(value) = self.value.0.value.get() { + value.py_gc_traverse(visit)?; + } + Ok(()) + } +} + +impl PyGcTraverse for Definitions { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + for value in self.0.values() { + if let Some(value) = value.0.value.get() { + value.py_gc_traverse(visit)?; + } + } + Ok(()) + } } #[derive(Clone, Debug)] pub struct DefinitionsBuilder { - definitions: AHashMap>, + definitions: Definitions, } -impl DefinitionsBuilder { +impl DefinitionsBuilder { pub fn new() -> Self { Self { - definitions: AHashMap::new(), + definitions: Definitions(AHashMap::new()), } } /// Get a ReferenceId for the given reference string. 
- // This ReferenceId can later be used to retrieve a definition - pub fn get_reference_id(&mut self, reference: &str) -> ReferenceId { - let next_id = self.definitions.len(); + pub fn get_definition(&mut self, reference: &str) -> DefinitionRef { // We either need a String copy or two hashmap lookups // Neither is better than the other // We opted for the easier outward facing API - match self.definitions.entry(reference.to_string()) { - Entry::Occupied(entry) => entry.get().id, - Entry::Vacant(entry) => { - entry.insert(Definition { - id: next_id, - value: None, - }); - next_id - } + let name = Arc::new(reference.to_string()); + let value = match self.definitions.0.entry(name.clone()) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner { + value: OnceLock::new(), + name: LazyName::new(), + }))), + }; + DefinitionRef { + name, + value: value.clone(), } } /// Add a definition, returning the ReferenceId that maps to it - pub fn add_definition(&mut self, reference: String, value: T) -> PyResult { - let next_id = self.definitions.len(); - match self.definitions.entry(reference.clone()) { - Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) { - Some(_) => py_schema_err!("Duplicate ref: `{}`", reference), - None => Ok(entry.get().id), - }, - Entry::Vacant(entry) => { - entry.insert(Definition { - id: next_id, - value: Some(value), - }); - Ok(next_id) + pub fn add_definition(&mut self, reference: String, value: T) -> PyResult> { + let name = Arc::new(reference); + let value = match self.definitions.0.entry(name.clone()) { + Entry::Occupied(entry) => { + let definition = entry.into_mut(); + match definition.0.value.set(value) { + Ok(()) => definition.clone(), + Err(_) => return py_schema_err!("Duplicate ref: `{}`", name), + } + } + Entry::Vacant(entry) => entry + .insert(Definition(Arc::new(DefinitionInner { + value: OnceLock::from(value), + name: LazyName::new(), + }))) + 
.clone(), + }; + Ok(DefinitionRef { name, value }) + } + + /// Consume this Definitions into a vector of items, indexed by each items ReferenceId + pub fn finish(self) -> PyResult> { + for (reference, def) in &self.definitions.0 { + if def.0.value.get().is_none() { + return py_schema_err!("Definitions error: definition `{}` was never filled", reference); } } + Ok(self.definitions) } +} - /// Retrieve an item definition using a ReferenceId - /// If the definition doesn't yet exist (as happens in recursive types) then we create it - /// At the end (in finish()) we check that there are no undefined definitions - pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> { - let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) { - Some(v) => v, - None => return py_schema_err!("Definitions error: no definition for ReferenceId `{}`", reference_id), - }; - match def.value.as_ref() { - Some(v) => Ok(v), - None => py_schema_err!( - "Definitions error: attempted to use `{}` before it was filled", - reference - ), +struct LazyName { + initialized: OnceLock, + in_recursion: AtomicBool, +} + +impl LazyName { + fn new() -> Self { + Self { + initialized: OnceLock::new(), + in_recursion: AtomicBool::new(false), } } - /// Consume this Definitions into a vector of items, indexed by each items ReferenceId - pub fn finish(self) -> PyResult> { - // We need to create a vec of defs according to the order in their ids - let mut defs: Vec<(usize, T)> = Vec::new(); - for (reference, def) in self.definitions { - match def.value { - None => return py_schema_err!("Definitions error: definition {} was never filled", reference), - Some(v) => defs.push((def.id, v)), - } + /// Gets the validator name, returning the default in the case of recursion loops + fn get_or_init(&self, init: impl FnOnce() -> String) -> &str { + if let Some(s) = self.initialized.get() { + return s.as_str(); + } + + if self + .in_recursion + .compare_exchange(false, true, 
Ordering::SeqCst, Ordering::SeqCst) + .is_err() + { + return "..."; + } + let result = self.initialized.get_or_init(init).as_str(); + self.in_recursion.store(false, Ordering::SeqCst); + result + } +} + +impl Debug for LazyName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.initialized.get().map_or("...", String::as_str).fmt(f) + } +} + +impl Clone for LazyName { + fn clone(&self) -> Self { + Self { + initialized: OnceLock::new(), + in_recursion: AtomicBool::new(false), } - defs.sort_by_key(|(id, _)| *id); - Ok(defs.into_iter().map(|(_, v)| v).collect()) } } diff --git a/src/py_gc.rs b/src/py_gc.rs index 02df02e13..8af285afb 100644 --- a/src/py_gc.rs +++ b/src/py_gc.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use ahash::AHashMap; use enum_dispatch::enum_dispatch; use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit}; @@ -35,6 +37,12 @@ impl PyGcTraverse for AHashMap { } } +impl PyGcTraverse for Arc { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + T::py_gc_traverse(self, visit) + } +} + impl PyGcTraverse for Box { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { T::py_gc_traverse(self, visit) diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 9972a82c4..65c5a1ba9 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -10,8 +10,6 @@ use serde::ser::Error; use super::config::SerializationConfig; use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER}; use super::ob_type::ObTypeLookup; -use super::shared::CombinedSerializer; -use crate::definitions::Definitions; use crate::recursion_guard::RecursionGuard; /// this is ugly, would be much better if extra could be stored in `SerializationState` @@ -48,7 +46,6 @@ impl SerializationState { Extra::new( py, mode, - &[], by_alias, &self.warnings, false, @@ -72,7 +69,6 @@ impl SerializationState { #[cfg_attr(debug_assertions, derive(Debug))] pub(crate) struct 
Extra<'a> { pub mode: &'a SerMode, - pub definitions: &'a Definitions, pub ob_type_lookup: &'a ObTypeLookup, pub warnings: &'a CollectWarnings, pub by_alias: bool, @@ -98,7 +94,6 @@ impl<'a> Extra<'a> { pub fn new( py: Python<'a>, mode: &'a SerMode, - definitions: &'a Definitions, by_alias: bool, warnings: &'a CollectWarnings, exclude_unset: bool, @@ -112,7 +107,6 @@ impl<'a> Extra<'a> { ) -> Self { Self { mode, - definitions, ob_type_lookup: ObTypeLookup::cached(py), warnings, by_alias, @@ -156,7 +150,6 @@ impl SerCheck { #[cfg_attr(debug_assertions, derive(Debug))] pub(crate) struct ExtraOwned { mode: SerMode, - definitions: Vec, warnings: CollectWarnings, by_alias: bool, exclude_unset: bool, @@ -176,7 +169,6 @@ impl ExtraOwned { pub fn new(extra: &Extra) -> Self { Self { mode: extra.mode.clone(), - definitions: extra.definitions.to_vec(), warnings: extra.warnings.clone(), by_alias: extra.by_alias, exclude_unset: extra.exclude_unset, @@ -196,7 +188,6 @@ impl ExtraOwned { pub fn to_extra<'py>(&'py self, py: Python<'py>) -> Extra<'py> { Extra { mode: &self.mode, - definitions: &self.definitions, ob_type_lookup: ObTypeLookup::cached(py), warnings: &self.warnings, by_alias: self.by_alias, diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 6dbc076fe..72028346b 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -5,7 +5,7 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; use pyo3::{PyTraverseError, PyVisit}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::py_gc::PyGcTraverse; use config::SerializationConfig; @@ -30,7 +30,7 @@ mod type_serializers; #[derive(Debug)] pub struct SchemaSerializer { serializer: CombinedSerializer, - definitions: Vec, + definitions: Definitions, expected_json_size: AtomicUsize, config: SerializationConfig, } @@ -54,7 +54,6 @@ impl SchemaSerializer { Extra::new( py, mode, - &self.definitions, by_alias, warnings, exclude_unset, @@ 
-184,9 +183,7 @@ impl SchemaSerializer { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { self.serializer.py_gc_traverse(&visit)?; - for slot in &self.definitions { - slot.py_gc_traverse(&visit)?; - } + self.definitions.py_gc_traverse(&visit)?; Ok(()) } } diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index b9b0c1fe1..7c24ff6db 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -13,7 +13,7 @@ use serde_json::ser::PrettyFormatter; use crate::build_tools::py_schema_err; use crate::build_tools::py_schema_error_type; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::py_gc::PyGcTraverse; use crate::tools::{py_err, SchemaDict}; @@ -293,7 +293,7 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug { fn get_name(&self) -> &str; /// Used by union serializers to decide if it's worth trying again while allowing subclasses - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { false } diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index 124f962ad..787e267dd 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -6,7 +6,7 @@ use std::borrow::Cow; use ahash::AHashMap; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use super::{ @@ -179,7 +179,7 @@ impl TypeSerializer for DataclassSerializer { &self.name } - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { true } } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 4614bbc56..cf92df244 100644 --- 
a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; -use crate::definitions::Definitions; +use crate::definitions::DefinitionRef; use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; @@ -41,7 +41,7 @@ impl BuildSerializer for DefinitionsSerializerBuilder { #[derive(Debug, Clone)] pub struct DefinitionRefSerializer { - serializer_id: usize, + definition: DefinitionRef, } impl BuildSerializer for DefinitionRefSerializer { @@ -52,9 +52,9 @@ impl BuildSerializer for DefinitionRefSerializer { _config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let schema_ref: String = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - let serializer_id = definitions.get_reference_id(&schema_ref); - Ok(Self { serializer_id }.into()) + let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; + let definition = definitions.get_definition(schema_ref); + Ok(Self { definition }.into()) } } @@ -68,10 +68,10 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - let value_id = extra.rec_guard.add(value, self.serializer_id)?; - let comb_serializer = extra.definitions.get(self.serializer_id).unwrap(); + let comb_serializer = self.definition.get().unwrap(); + let value_id = extra.rec_guard.add(value, self.definition.id())?; let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.serializer_id); + extra.rec_guard.pop(value_id, self.definition.id()); r } @@ -87,10 +87,13 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let value_id = extra.rec_guard.add(value, self.serializer_id).map_err(py_err_se_err)?; - let comb_serializer = extra.definitions.get(self.serializer_id).unwrap(); + let comb_serializer = 
self.definition.get().unwrap(); + let value_id = extra + .rec_guard + .add(value, self.definition.id()) + .map_err(py_err_se_err)?; let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.serializer_id); + extra.rec_guard.pop(value_id, self.definition.id()); r } @@ -98,8 +101,7 @@ impl TypeSerializer for DefinitionRefSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - let comb_serializer = definitions.get(self.serializer_id).unwrap(); - comb_serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.definition.get().unwrap().retry_with_lax_check() } } diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index c5b252fbf..8a2eeb4e1 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -13,7 +13,7 @@ use super::{ }; use crate::build_tools::py_schema_err; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::serializers::errors::PydanticSerializationUnexpectedValue; use crate::tools::SchemaDict; @@ -228,7 +228,7 @@ impl TypeSerializer for ModelSerializer { &self.name } - fn retry_with_lax_check(&self, _definitions: &Definitions) -> bool { + fn retry_with_lax_check(&self) -> bool { true } } diff --git a/src/serializers/type_serializers/nullable.rs b/src/serializers/type_serializers/nullable.rs index 837d6c5f1..23349ec81 100644 --- a/src/serializers/type_serializers/nullable.rs +++ b/src/serializers/type_serializers/nullable.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use super::{infer_json_key_known, BuildSerializer, 
CombinedSerializer, Extra, IsType, ObType, TypeSerializer}; @@ -75,7 +75,7 @@ impl TypeSerializer for NullableSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() } } diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 70818959e..788620408 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -4,7 +4,7 @@ use pyo3::types::{PyDict, PyList, PyTuple}; use std::borrow::Cow; use crate::build_tools::py_schema_err; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use crate::PydanticSerializationUnexpectedValue; @@ -87,7 +87,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -116,7 +116,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.json_key(key, &new_extra) { @@ -153,7 +153,7 @@ impl TypeSerializer for UnionSerializer { }, } } - if self.retry_with_lax_check(extra.definitions) { + if self.retry_with_lax_check() { new_extra.check = SerCheck::Lax; for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { @@ -174,10 +174,8 @@ impl TypeSerializer for UnionSerializer { &self.name } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.choices - .iter() - .any(|choice| choice.retry_with_lax_check(definitions)) + fn 
retry_with_lax_check(&self) -> bool { + self.choices.iter().any(CombinedSerializer::retry_with_lax_check) } } diff --git a/src/serializers/type_serializers/with_default.rs b/src/serializers/type_serializers/with_default.rs index 148c05052..d20c273a1 100644 --- a/src/serializers/type_serializers/with_default.rs +++ b/src/serializers/type_serializers/with_default.rs @@ -4,7 +4,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; -use crate::definitions::{Definitions, DefinitionsBuilder}; +use crate::definitions::DefinitionsBuilder; use crate::tools::SchemaDict; use crate::validators::DefaultType; @@ -67,8 +67,8 @@ impl TypeSerializer for WithDefaultSerializer { Self::EXPECTED_TYPE } - fn retry_with_lax_check(&self, definitions: &Definitions) -> bool { - self.serializer.retry_with_lax_check(definitions) + fn retry_with_lax_check(&self) -> bool { + self.serializer.retry_with_lax_check() } fn get_default(&self, py: Python) -> PyResult> { diff --git a/src/validators/any.rs b/src/validators/any.rs index eddde1725..625eb4adf 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -34,11 +34,7 @@ impl Validator for AnyValidator { Ok(input.to_object(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -46,7 +42,7 @@ impl Validator for AnyValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 2c0fe4a0a..aa0870043 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -15,7 +15,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] 
struct Parameter { positional: bool, name: String, @@ -24,7 +24,7 @@ struct Parameter { validator: CombinedValidator, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ArgumentsValidator { parameters: Vec, positional_params_count: usize, @@ -332,29 +332,25 @@ impl Validator for ArgumentsValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.parameters .iter() - .any(|p| p.validator.different_strict_behavior(definitions, ultra_strict)) + .any(|p| p.validator.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { self.parameters - .iter_mut() - .try_for_each(|parameter| parameter.validator.complete(definitions))?; - if let Some(v) = &mut self.var_args_validator { - v.complete(definitions)?; + .iter() + .try_for_each(|parameter| parameter.validator.complete())?; + if let Some(v) = &self.var_args_validator { + v.complete()?; } - if let Some(v) = &mut self.var_kwargs_validator { - v.complete(definitions)?; + if let Some(v) = &self.var_kwargs_validator { + v.complete()?; }; Ok(()) } diff --git a/src/validators/bool.rs b/src/validators/bool.rs index d87c1c1d7..3a38cf3e5 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -42,11 +42,7 @@ impl Validator for BoolValidator { Ok(input.validate_bool(strict)?.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -54,7 +50,7 @@ impl Validator for BoolValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git 
a/src/validators/bytes.rs b/src/validators/bytes.rs index 2084f916e..0f662af1c 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -50,11 +50,7 @@ impl Validator for BytesValidator { Ok(either_bytes.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -62,7 +58,7 @@ impl Validator for BytesValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -112,11 +108,7 @@ impl Validator for BytesConstrainedValidator { Ok(either_bytes.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -124,7 +116,7 @@ impl Validator for BytesConstrainedValidator { "constrained-bytes" } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/call.rs b/src/validators/call.rs index 24c7f4111..3f6eb4d35 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -11,7 +11,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct CallValidator { function: PyObject, arguments_validator: Box, @@ -98,28 +98,23 @@ impl Validator for CallValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if let Some(return_validator) = &self.return_validator { - if return_validator.different_strict_behavior(definitions, 
ultra_strict) { + if return_validator.different_strict_behavior(ultra_strict) { return true; } } - self.arguments_validator - .different_strict_behavior(definitions, ultra_strict) + self.arguments_validator.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.arguments_validator.complete(definitions)?; - match &mut self.return_validator { - Some(v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + self.arguments_validator.complete()?; + match &self.return_validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/callable.rs b/src/validators/callable.rs index 9b565e3eb..83eb37cbe 100644 --- a/src/validators/callable.rs +++ b/src/validators/callable.rs @@ -36,11 +36,7 @@ impl Validator for CallableValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -48,7 +44,7 @@ impl Validator for CallableValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/chain.rs b/src/validators/chain.rs index 001947d1f..c0f356fa0 100644 --- a/src/validators/chain.rs +++ b/src/validators/chain.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::validation_state::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ChainValidator { steps: Vec, name: String, @@ -83,21 +83,15 @@ impl Validator for ChainValidator { steps_iter.try_fold(value, |v, step| step.validate(py, v.into_ref(py), state)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - 
self.steps - .iter() - .any(|v| v.different_strict_behavior(definitions, ultra_strict)) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.steps.iter().any(|v| v.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.steps.iter_mut().try_for_each(|v| v.complete(definitions)) + fn complete(&self) -> PyResult<()> { + self.steps.iter().try_for_each(CombinedValidator::complete) } } diff --git a/src/validators/custom_error.rs b/src/validators/custom_error.rs index 1e8258090..0d9931c62 100644 --- a/src/validators/custom_error.rs +++ b/src/validators/custom_error.rs @@ -57,7 +57,7 @@ impl CustomError { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct CustomErrorValidator { validator: Box, custom_error: CustomError, @@ -99,19 +99,15 @@ impl Validator for CustomErrorValidator { .map_err(|_| self.custom_error.as_val_error(input)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.validator.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 117596b9f..dff7735a3 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -19,7 +19,7 @@ use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; -#[derive(Debug, Clone)] +#[derive(Debug)] struct Field { kw_only: bool, name: String, @@ -30,7 +30,7 @@ struct Field { frozen: bool, } -#[derive(Debug, 
Clone)] +#[derive(Debug)] pub struct DataclassArgsValidator { fields: Vec, positional_count: usize, @@ -426,28 +426,22 @@ impl Validator for DataclassArgsValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.fields .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) + .any(|f| f.validator.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { &self.validator_name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|field| field.validator.complete(definitions)) + fn complete(&self) -> PyResult<()> { + self.fields.iter().try_for_each(|field| field.validator.complete()) } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DataclassValidator { strict: bool, validator: Box, @@ -588,13 +582,9 @@ impl Validator for DataclassValidator { Ok(obj.to_object(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) + self.validator.different_strict_behavior(ultra_strict) } else { true } @@ -604,8 +594,8 @@ impl Validator for DataclassValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/src/validators/date.rs b/src/validators/date.rs index a771a5045..3549f66f0 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -96,11 +96,7 @@ impl Validator for DateValidator { Ok(date.try_into_py(py)?) 
} - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -108,7 +104,7 @@ impl Validator for DateValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 7596b7aca..baf4ca467 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -125,11 +125,7 @@ impl Validator for DateTimeValidator { Ok(datetime.try_into_py(py)?) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -137,7 +133,7 @@ impl Validator for DateTimeValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index 2564e096a..211befe07 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -230,11 +230,7 @@ impl Validator for DecimalValidator { Ok(decimal.into()) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { true } @@ -242,7 +238,7 @@ impl Validator for DecimalValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 3a35fce4c..16aea8cd4 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -1,7 +1,12 @@ +use std::cell::RefCell; + +use ahash::HashSet; +use 
ahash::HashSetExt; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; +use crate::definitions::DefinitionRef; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; @@ -39,17 +44,12 @@ impl BuildValidator for DefinitionsValidatorBuilder { #[derive(Debug, Clone)] pub struct DefinitionRefValidator { - validator_id: usize, - inner_name: String, - // we have to record the answers to `Question`s as we can't access the validator when `ask()` is called + definition: DefinitionRef, } impl DefinitionRefValidator { - pub fn new(validator_id: usize) -> Self { - Self { - validator_id, - inner_name: "...".to_string(), - } + pub fn new(definition: DefinitionRef) -> Self { + Self { definition } } } @@ -61,15 +61,10 @@ impl BuildValidator for DefinitionRefValidator { _config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let schema_ref: String = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - - let validator_id = definitions.get_reference_id(&schema_ref); + let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; - Ok(Self { - validator_id, - inner_name: "...".to_string(), - } - .into()) + let definition = definitions.get_definition(schema_ref); + Ok(Self::new(definition).into()) } } @@ -82,21 +77,22 @@ impl Validator for DefinitionRefValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + let validator = self.definition.get().unwrap(); if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.validator_id) { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) } else { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } - let output = 
validate(self.validator_id, py, input, state); - state.recursion_guard.remove(id, self.validator_id); + let output = validator.validate(py, input, state); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } } else { - validate(self.validator_id, py, input, state) + validator.validate(py, input, state) } } @@ -108,69 +104,51 @@ impl Validator for DefinitionRefValidator { field_value: &'data PyAny, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + let validator = self.definition.get().unwrap(); if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.validator_id) { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) } else { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } - let output = validate_assignment(self.validator_id, py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.validator_id); + let output = validator.validate_assignment(py, obj, field_name, field_value, state); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } } else { - validate_assignment(self.validator_id, py, obj, field_name, field_value, state) + validator.validate_assignment(py, obj, field_name, field_value, state) } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - if let Some(definitions) = definitions { - // have to unwrap here, because we can't return an error from this function, should be okay - let validator = definitions.get_definition(self.validator_id).unwrap(); - validator.different_strict_behavior(None, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + 
thread_local! { + static RECURSION_SET: RefCell>> = RefCell::new(None); + } + + let id = self as *const _ as usize; + // have to unwrap here, because we can't return an error from this function, should be okay + let validator: &CombinedValidator = self.definition.get().unwrap(); + if RECURSION_SET.with( + |set: &RefCell>>| { + set.borrow_mut().get_or_insert_with(HashSet::new).insert(id) + }, + ) { + let different_strict_behavior = validator.different_strict_behavior(ultra_strict); + RECURSION_SET.with(|set| set.borrow_mut().get_or_insert_with(HashSet::new).remove(&id)); + different_strict_behavior } else { false } } fn get_name(&self) -> &str { - &self.inner_name + self.definition.get_or_init_name(|v| v.get_name().into()) } - /// don't need to call complete on the inner validator here, complete_validators takes care of that. - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - let validator = definitions.get_definition(self.validator_id)?; - self.inner_name = validator.get_name().to_string(); + fn complete(&self) -> PyResult<()> { Ok(()) } } - -fn validate<'data>( - validator_id: usize, - py: Python<'data>, - input: &'data impl Input<'data>, - state: &mut ValidationState, -) -> ValResult<'data, PyObject> { - let validator = state.definitions.get(validator_id).unwrap(); - validator.validate(py, input, state) -} - -#[allow(clippy::too_many_arguments)] -fn validate_assignment<'data>( - validator_id: usize, - py: Python<'data>, - obj: &'data PyAny, - field_name: &'data str, - field_value: &'data PyAny, - state: &mut ValidationState, -) -> ValResult<'data, PyObject> { - let validator = state.definitions.get(validator_id).unwrap(); - validator.validate_assignment(py, obj, field_name, field_value, state) -} diff --git a/src/validators/dict.rs b/src/validators/dict.rs index dc8f03937..250145290 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -16,7 +16,7 @@ use super::any::AnyValidator; use super::list::length_check; use 
super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DictValidator { strict: bool, key_validator: Box, @@ -92,14 +92,9 @@ impl Validator for DictValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.key_validator.different_strict_behavior(definitions, true) - || self.value_validator.different_strict_behavior(definitions, true) + self.key_validator.different_strict_behavior(true) || self.value_validator.different_strict_behavior(true) } else { true } @@ -109,9 +104,9 @@ impl Validator for DictValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.key_validator.complete(definitions)?; - self.value_validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.key_validator.complete()?; + self.value_validator.complete() } } diff --git a/src/validators/float.rs b/src/validators/float.rs index f0eb41750..2e9434d9f 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -76,11 +76,7 @@ impl Validator for FloatValidator { Ok(either_float.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { true } @@ -88,7 +84,7 @@ impl Validator for FloatValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -179,11 +175,7 @@ impl Validator for ConstrainedFloatValidator { Ok(either_float.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, 
_ultra_strict: bool) -> bool { true } @@ -191,7 +183,7 @@ impl Validator for ConstrainedFloatValidator { "constrained-float" } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index ad7708324..4b4cdcb6f 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -10,7 +10,7 @@ use super::set::set_build; use super::validation_state::ValidationState; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FrozenSetValidator { strict: bool, item_validator: Box, @@ -48,13 +48,9 @@ impl Validator for FrozenSetValidator { Ok(f_set.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.item_validator.different_strict_behavior(definitions, true) + self.item_validator.different_strict_behavior(true) } else { true } @@ -64,7 +60,7 @@ impl Validator for FrozenSetValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.item_validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.item_validator.complete() } } diff --git a/src/validators/function.rs b/src/validators/function.rs index be0d6374f..adb143696 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use pyo3::exceptions::{PyAssertionError, PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict, PyString}; @@ -111,14 +113,9 @@ macro_rules! 
impl_validator { self._validate(validate, py, obj, state) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.validator - .different_strict_behavior(definitions, ultra_strict) + self.validator.different_strict_behavior(ultra_strict) } else { true } @@ -128,14 +125,14 @@ macro_rules! impl_validator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } }; } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionBeforeValidator { validator: Box, func: PyObject, @@ -168,7 +165,7 @@ impl FunctionBeforeValidator { impl_validator!(FunctionBeforeValidator); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionAfterValidator { validator: Box, func: PyObject, @@ -255,11 +252,7 @@ impl Validator for FunctionPlainValidator { r.map_err(|e| convert_err(py, e, input)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { // best guess, should we change this? 
!ultra_strict } @@ -268,14 +261,14 @@ impl Validator for FunctionPlainValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FunctionWrapValidator { - validator: Box, + validator: Arc, func: PyObject, config: PyObject, name: String, @@ -299,7 +292,7 @@ impl BuildValidator for FunctionWrapValidator { let hide_input_in_errors: bool = config.get_as(intern!(py, "hide_input_in_errors"))?.unwrap_or(false); let validation_error_cause: bool = config.get_as(intern!(py, "validation_error_cause"))?.unwrap_or(false); Ok(Self { - validator: Box::new(validator), + validator: Arc::new(validator), func: function_info.function.clone(), config: match config { Some(c) => c.into(), @@ -350,7 +343,7 @@ impl Validator for FunctionWrapValidator { validator: InternalValidator::new( py, "ValidatorCallable", - &self.validator, + self.validator.clone(), state, self.hide_input_in_errors, self.validation_error_cause, @@ -376,7 +369,7 @@ impl Validator for FunctionWrapValidator { validator: InternalValidator::new( py, "AssignmentValidatorCallable", - &self.validator, + self.validator.clone(), state, self.hide_input_in_errors, self.validation_error_cause, @@ -387,13 +380,9 @@ impl Validator for FunctionWrapValidator { self._validate(Py::new(py, handler)?.into_ref(py), py, obj, state) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) + self.validator.different_strict_behavior(ultra_strict) } else { true } @@ -403,13 +392,13 @@ impl Validator for FunctionWrapValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> 
PyResult<()> { + self.validator.complete() } } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct ValidatorCallable { validator: InternalValidator, } @@ -441,7 +430,7 @@ impl ValidatorCallable { } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct AssignmentValidatorCallable { updated_field_name: String, updated_field_value: Py, diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 0cff7e28e..111b5c101 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -1,4 +1,5 @@ use std::fmt; +use std::sync::Arc; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -14,7 +15,7 @@ use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, InputT #[derive(Debug, Clone)] pub struct GeneratorValidator { - item_validator: Option>, + item_validator: Option>, min_length: Option, max_length: Option, name: String, @@ -30,7 +31,7 @@ impl BuildValidator for GeneratorValidator { config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - let item_validator = get_items_schema(schema, config, definitions)?; + let item_validator = get_items_schema(schema, config, definitions)?.map(Arc::new); let name = match item_validator { Some(ref v) => format!("{}[{}]", Self::EXPECTED_TYPE, v.get_name()), None => format!("{}[any]", Self::EXPECTED_TYPE), @@ -67,7 +68,7 @@ impl Validator for GeneratorValidator { InternalValidator::new( py, "ValidatorIterator", - v, + v.clone(), state, self.hide_input_in_errors, self.validation_error_cause, @@ -85,13 +86,9 @@ impl Validator for GeneratorValidator { Ok(v_iterator.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if let Some(ref v) = self.item_validator { - v.different_strict_behavior(definitions, ultra_strict) + 
v.different_strict_behavior(ultra_strict) } else { false } @@ -101,16 +98,16 @@ impl Validator for GeneratorValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + match &self.item_validator { + Some(v) => v.complete(), None => Ok(()), } } } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] struct ValidatorIterator { iterator: GenericIterator, validator: Option, @@ -217,13 +214,11 @@ impl ValidatorIterator { } } -/// Cloneable validator wrapper for use in generators in functions, this can be passed back to python +/// Owned validator wrapper for use in generators in functions, this can be passed back to python /// mid-validation -#[derive(Clone)] pub struct InternalValidator { name: String, - validator: CombinedValidator, - definitions: Vec, + validator: Arc, // TODO, do we need data? 
data: Option>, strict: Option, @@ -246,7 +241,7 @@ impl InternalValidator { pub fn new( py: Python, name: &str, - validator: &CombinedValidator, + validator: Arc, state: &ValidationState, hide_input_in_errors: bool, validation_error_cause: bool, @@ -254,8 +249,7 @@ impl InternalValidator { let extra = state.extra(); Self { name: name.to_string(), - validator: validator.clone(), - definitions: state.definitions.to_vec(), + validator, data: extra.data.map(|d| d.into_py(py)), strict: extra.strict, from_attributes: extra.from_attributes, @@ -285,7 +279,7 @@ impl InternalValidator { context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; - let mut state = ValidationState::new(extra, &self.definitions, &mut self.recursion_guard); + let mut state = ValidationState::new(extra, &mut self.recursion_guard); self.validator .validate_assignment(py, model, field_name, field_value, &mut state) .map_err(|e| { @@ -316,7 +310,7 @@ impl InternalValidator { context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; - let mut state = ValidationState::new(extra, &self.definitions, &mut self.recursion_guard); + let mut state = ValidationState::new(extra, &mut self.recursion_guard); self.validator.validate(py, input, &mut state).map_err(|e| { ValidationError::from_val_error( py, @@ -333,7 +327,6 @@ impl InternalValidator { impl_py_gc_traverse!(InternalValidator { validator, - definitions, data, context, self_instance diff --git a/src/validators/int.rs b/src/validators/int.rs index 3fba2199d..0903e1998 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -54,11 +54,7 @@ impl Validator for IntValidator { Ok(either_int.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { 
!ultra_strict } @@ -66,7 +62,7 @@ impl Validator for IntValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -151,11 +147,7 @@ impl Validator for ConstrainedIntValidator { Ok(either_int.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -163,7 +155,7 @@ impl Validator for ConstrainedIntValidator { "constrained-int" } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index 78705482c..e64d0717c 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -83,11 +83,7 @@ impl Validator for IsInstanceValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -95,7 +91,7 @@ impl Validator for IsInstanceValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/is_subclass.rs b/src/validators/is_subclass.rs index d0f5a6cfe..0866fa1e7 100644 --- a/src/validators/is_subclass.rs +++ b/src/validators/is_subclass.rs @@ -62,11 +62,7 @@ impl Validator for IsSubclassValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -74,7 +70,7 @@ impl Validator for IsSubclassValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> 
PyResult<()> { Ok(()) } } diff --git a/src/validators/json.rs b/src/validators/json.rs index 5eda007be..fd832f874 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -9,7 +9,7 @@ use crate::tools::SchemaDict; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct JsonValidator { validator: Option>, name: String, @@ -61,13 +61,9 @@ impl Validator for JsonValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if let Some(ref v) = self.validator { - v.different_strict_behavior(definitions, ultra_strict) + v.different_strict_behavior(ultra_strict) } else { false } @@ -77,9 +73,9 @@ impl Validator for JsonValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.validator { - Some(ref mut v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + match &self.validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/json_or_python.rs b/src/validators/json_or_python.rs index 828532fe5..cd952bed1 100644 --- a/src/validators/json_or_python.rs +++ b/src/validators/json_or_python.rs @@ -11,7 +11,7 @@ use super::InputType; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct JsonOrPython { json: Box, python: Box, @@ -63,21 +63,16 @@ impl Validator for JsonOrPython { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.json.different_strict_behavior(definitions, ultra_strict) - || self.python.different_strict_behavior(definitions, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + 
self.json.different_strict_behavior(ultra_strict) || self.python.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.json.complete(definitions)?; - self.python.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.json.complete()?; + self.python.complete() } } diff --git a/src/validators/lax_or_strict.rs b/src/validators/lax_or_strict.rs index 9681cf689..b5cec61be 100644 --- a/src/validators/lax_or_strict.rs +++ b/src/validators/lax_or_strict.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct LaxOrStrictValidator { strict: bool, lax_validator: Box, @@ -68,13 +68,9 @@ impl Validator for LaxOrStrictValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.strict_validator.different_strict_behavior(definitions, true) + self.strict_validator.different_strict_behavior(true) } else { true } @@ -84,8 +80,8 @@ impl Validator for LaxOrStrictValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.lax_validator.complete(definitions)?; - self.strict_validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.lax_validator.complete()?; + self.strict_validator.complete() } } diff --git a/src/validators/list.rs b/src/validators/list.rs index ffd7a118e..8e931657f 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -1,3 +1,5 @@ +use std::sync::OnceLock; + use pyo3::prelude::*; use pyo3::types::PyDict; @@ -7,26 +9,26 @@ use crate::tools::SchemaDict; use super::{build_validator, BuildValidator, CombinedValidator, 
DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ListValidator { strict: bool, item_validator: Option>, min_length: Option, max_length: Option, - name: String, + name: OnceLock, } pub fn get_items_schema( schema: &PyDict, config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, -) -> PyResult>> { +) -> PyResult> { match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { Some(d) => { let validator = build_validator(d, config, definitions)?; match validator { CombinedValidator::Any(_) => Ok(None), - _ => Ok(Some(Box::new(validator))), + _ => Ok(Some(validator)), } } None => Ok(None), @@ -98,15 +100,13 @@ impl BuildValidator for ListValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; - let inner_name = item_validator.as_ref().map_or("any", |v| v.get_name()); - let name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); + let item_validator = get_items_schema(schema, config, definitions)?.map(Box::new); Ok(Self { strict: crate::build_tools::is_strict(schema, config)?, item_validator, min_length: schema.get_as(pyo3::intern!(py, "min_length"))?, max_length: schema.get_as(pyo3::intern!(py, "max_length"))?, - name, + name: OnceLock::new(), } .into()) } @@ -138,14 +138,10 @@ impl Validator for ListValidator { Ok(output.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), + Some(ref v) => v.different_strict_behavior(true), None => false, } } else { @@ -154,14 +150,27 @@ impl Validator for ListValidator { } fn get_name(&self) -> &str { - &self.name + // The logic here is a little janky, it's done to try to cache the formatted name + // while 
also trying to render definitions correctly when possible. + // + // Probably an opportunity for a future refactor + match self.name.get() { + Some(s) => s.as_str(), + None => { + let name = self.item_validator.as_ref().map_or("any", |v| v.get_name()); + if name == "..." { + // when inner name is not initialized yet, don't cache it here + "list[...]" + } else { + self.name.get_or_init(|| format!("list[{name}]")).as_str() + } + } + } } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - if let Some(ref mut v) = self.item_validator { - v.complete(definitions)?; - let inner_name = v.get_name(); - self.name = format!("{}[{inner_name}]", Self::EXPECTED_TYPE); + fn complete(&self) -> PyResult<()> { + if let Some(v) = &self.item_validator { + v.complete()?; } Ok(()) } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index de394affb..25cb94bd9 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -22,7 +22,7 @@ struct BoolLiteral { } #[derive(Debug, Clone)] -pub struct LiteralLookup { +pub struct LiteralLookup { // Specialized lookups for ints, bools and strings because they // (1) are easy to convert between Rust and Python // (2) hashing them in Rust is very fast @@ -35,7 +35,7 @@ pub struct LiteralLookup { pub values: Vec, } -impl LiteralLookup { +impl LiteralLookup { pub fn new<'py>(py: Python<'py>, expected: impl Iterator) -> PyResult { let mut expected_int = AHashMap::new(); let mut expected_str: AHashMap = AHashMap::new(); @@ -135,7 +135,7 @@ impl LiteralLookup { } } -impl PyGcTraverse for LiteralLookup { +impl PyGcTraverse for LiteralLookup { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { self.expected_py.py_gc_traverse(visit)?; self.values.py_gc_traverse(visit)?; @@ -198,11 +198,7 @@ impl Validator for LiteralValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn 
different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -210,7 +206,7 @@ impl Validator for LiteralValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 4ee677663..42aad2001 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -9,7 +9,7 @@ use pyo3::types::{PyAny, PyDict, PyTuple, PyType}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::build_tools::{py_schema_err, py_schema_error_type, SchemaError}; -use crate::definitions::DefinitionsBuilder; +use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::errors::{LocItem, ValError, ValResult, ValidationError}; use crate::input::{Input, InputType, StringMapping}; use crate::py_gc::PyGcTraverse; @@ -98,10 +98,10 @@ impl PySome { } #[pyclass(module = "pydantic_core._pydantic_core")] -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SchemaValidator { validator: CombinedValidator, - definitions: Vec, + definitions: Definitions, schema: PyObject, #[pyo3(get)] title: PyObject, @@ -115,11 +115,11 @@ impl SchemaValidator { pub fn py_new(py: Python, schema: &PyAny, config: Option<&PyDict>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); - let mut validator = build_validator(schema, config, &mut definitions_builder)?; - validator.complete(&definitions_builder)?; - let mut definitions = definitions_builder.clone().finish()?; - for val in &mut definitions { - val.complete(&definitions_builder)?; + let validator = build_validator(schema, config, &mut definitions_builder)?; + let definitions = definitions_builder.finish()?; + validator.complete()?; + for val in definitions.values() { + val.get().unwrap().complete()?; } let config_title = match config { Some(c) => c.get_item("title"), @@ -141,9 +141,10 @@ impl SchemaValidator { }) } - pub fn __reduce__(&self, py: Python) -> PyResult { - 
let args = (self.schema.as_ref(py),); - let cls = Py::new(py, self.clone())?.getattr(py, "__class__")?; + pub fn __reduce__(slf: &PyCell) -> PyResult { + let py = slf.py(); + let args = (slf.try_borrow()?.schema.to_object(py),); + let cls = slf.getattr("__class__")?; Ok((cls, args).into_py(py)) } @@ -266,7 +267,7 @@ impl SchemaValidator { }; let guard = &mut RecursionGuard::default(); - let mut state = ValidationState::new(extra, &self.definitions, guard); + let mut state = ValidationState::new(extra, guard); self.validator .validate_assignment(py, obj, field_name, field_value, &mut state) .map_err(|e| self.prepare_validation_err(py, e, InputType::Python)) @@ -284,7 +285,7 @@ impl SchemaValidator { self_instance: None, }; let recursion_guard = &mut RecursionGuard::default(); - let mut state = ValidationState::new(extra, &self.definitions, recursion_guard); + let mut state = ValidationState::new(extra, recursion_guard); let r = self.validator.default_value(py, None::, &mut state); match r { Ok(maybe_default) => match maybe_default { @@ -307,9 +308,6 @@ impl SchemaValidator { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { self.validator.py_gc_traverse(&visit)?; visit.call(&self.schema)?; - for slot in &self.definitions { - slot.py_gc_traverse(&visit)?; - } Ok(()) } } @@ -332,7 +330,6 @@ impl SchemaValidator { { let mut state = ValidationState::new( Extra::new(strict, from_attributes, context, self_instance, input_type), - &self.definitions, recursion_guard, ); self.validator.validate(py, input, &mut state) @@ -371,7 +368,6 @@ impl<'py> SelfValidator<'py> { let mut recursion_guard = RecursionGuard::default(); let mut state = ValidationState::new( Extra::new(strict, None, None, None, InputType::Python), - &self.validator.definitions, &mut recursion_guard, ); match self.validator.validator.validate(py, schema, &mut state) { @@ -388,14 +384,14 @@ impl<'py> SelfValidator<'py> { let mut definitions_builder = DefinitionsBuilder::new(); - let mut 
validator = match build_validator(self_schema, None, &mut definitions_builder) { + let validator = match build_validator(self_schema, None, &mut definitions_builder) { Ok(v) => v, Err(err) => return py_schema_err!("Error building self-schema:\n {}", err), }; - validator.complete(&definitions_builder)?; - let mut definitions = definitions_builder.clone().finish()?; - for val in &mut definitions { - val.complete(&definitions_builder)?; + let definitions = definitions_builder.finish()?; + validator.complete()?; + for val in definitions.values() { + val.get().unwrap().complete()?; } Ok(SchemaValidator { validator, @@ -603,7 +599,7 @@ impl<'a> Extra<'a> { } } -#[derive(Debug, Clone)] +#[derive(Debug)] #[enum_dispatch(PyGcTraverse)] pub enum CombinedValidator { // typed dict e.g. heterogeneous dicts or simply a model @@ -699,7 +695,7 @@ pub enum CombinedValidator { /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, /// validators defined in `build_validator` also need `EXPECTED_TYPE` as a const, but that can't be part of the trait #[enum_dispatch(CombinedValidator)] -pub trait Validator: Send + Sync + Clone + Debug { +pub trait Validator: Send + Sync + Debug { /// Do the actual validation for this schema/type fn validate<'data>( &self, @@ -734,17 +730,13 @@ pub trait Validator: Send + Sync + Clone + Debug { /// whether the validator behaves differently in strict mode, and in ultra strict mode /// implementations should return true if any of their sub-validators return true - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool; + fn different_strict_behavior(&self, ultra_strict: bool) -> bool; /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator /// this is used in the error location in unions, and in the top level message in `ValidationError` fn get_name(&self) -> &str; /// this method must be 
implemented for any validator which holds references to other validators, - /// it is used by `DefinitionRefValidator` to set its name - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()>; + /// it is used by `UnionValidator` to calculate strictness + fn complete(&self) -> PyResult<()>; } diff --git a/src/validators/model.rs b/src/validators/model.rs index 2ec7185a9..1459f56f5 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -50,7 +50,7 @@ impl Revalidate { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ModelValidator { revalidate: Revalidate, validator: Box, @@ -206,13 +206,9 @@ impl Validator for ModelValidator { Ok(model.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.validator.different_strict_behavior(definitions, ultra_strict) + self.validator.different_strict_behavior(ultra_strict) } else { true } @@ -222,8 +218,8 @@ impl Validator for ModelValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index f2654c33e..30d937cb8 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -20,7 +20,7 @@ use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuild use std::ops::ControlFlow; -#[derive(Debug, Clone)] +#[derive(Debug)] struct Field { name: String, lookup_key: LookupKey, @@ -31,7 +31,7 @@ struct Field { impl_py_gc_traverse!(Field { validator }); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct ModelFieldsValidator { fields: Vec, model_name: String, @@ -415,26 +415,20 @@ impl Validator for ModelFieldsValidator { Ok((new_data.to_object(py), 
new_extra, fields_set.to_object(py)).to_object(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.fields .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) + .any(|f| f.validator.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|f| f.validator.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + self.fields.iter().try_for_each(|f| f.validator.complete())?; + match &self.extras_validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/none.rs b/src/validators/none.rs index 36be70acb..f241be9d8 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -36,11 +36,7 @@ impl Validator for NoneValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - _ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { false } @@ -48,7 +44,7 @@ impl Validator for NoneValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 4b408f206..7f4cf19fc 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -9,7 +9,7 @@ use crate::tools::SchemaDict; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct NullableValidator { validator: Box, name: String, @@ -45,19 +45,15 @@ impl Validator for NullableValidator { } } - 
fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.validator.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.validator.complete() } } diff --git a/src/validators/set.rs b/src/validators/set.rs index e5e2cecf3..626572139 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -8,7 +8,7 @@ use crate::tools::SchemaDict; use super::list::min_length_check; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SetValidator { strict: bool, item_validator: Box, @@ -70,13 +70,9 @@ impl Validator for SetValidator { Ok(set.into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - self.item_validator.different_strict_behavior(definitions, true) + self.item_validator.different_strict_behavior(true) } else { true } @@ -86,7 +82,7 @@ impl Validator for SetValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.item_validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + self.item_validator.complete() } } diff --git a/src/validators/string.rs b/src/validators/string.rs index 6b646224d..4eab4602a 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] 
+#[derive(Debug)] pub struct StrValidator { strict: bool, coerce_numbers_to_str: bool, @@ -51,11 +51,7 @@ impl Validator for StrValidator { Ok(either_str.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -63,7 +59,7 @@ impl Validator for StrValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -150,11 +146,7 @@ impl Validator for StrConstrainedValidator { Ok(py_string.into_py(py)) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -162,7 +154,7 @@ impl Validator for StrConstrainedValidator { "constrained-str" } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/time.rs b/src/validators/time.rs index 7bbd7e511..f5e2be7c7 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -78,11 +78,7 @@ impl Validator for TimeValidator { Ok(time.try_into_py(py)?) 
} - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -90,7 +86,7 @@ impl Validator for TimeValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index 106d5a64a..21340f2f0 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -101,11 +101,7 @@ impl Validator for TimeDeltaValidator { Ok(py_timedelta.into()) } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -113,7 +109,7 @@ impl Validator for TimeDeltaValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 5c2c09bec..cfa239e3b 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -10,7 +10,7 @@ use crate::tools::SchemaDict; use super::list::{get_items_schema, min_length_check}; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TupleVariableValidator { strict: bool, item_validator: Option>, @@ -27,7 +27,7 @@ impl BuildValidator for TupleVariableValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?; + let item_validator = get_items_schema(schema, config, definitions)?.map(Box::new); let inner_name = item_validator.as_ref().map_or("any", |v| v.get_name()); let name = format!("tuple[{inner_name}, ...]"); 
Ok(Self { @@ -60,14 +60,10 @@ impl Validator for TupleVariableValidator { Ok(PyTuple::new(py, &output).into_py(py)) } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { match self.item_validator { - Some(ref v) => v.different_strict_behavior(definitions, true), + Some(ref v) => v.different_strict_behavior(true), None => false, } } else { @@ -79,15 +75,15 @@ impl Validator for TupleVariableValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - match self.item_validator { - Some(ref mut v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + match &self.item_validator { + Some(v) => v.complete(), None => Ok(()), } } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TuplePositionalValidator { strict: bool, items_validators: Vec, @@ -242,20 +238,12 @@ impl Validator for TuplePositionalValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { if ultra_strict { - if self - .items_validators - .iter() - .any(|v| v.different_strict_behavior(definitions, true)) - { + if self.items_validators.iter().any(|v| v.different_strict_behavior(true)) { true } else if let Some(ref v) = self.extras_validator { - v.different_strict_behavior(definitions, true) + v.different_strict_behavior(true) } else { false } @@ -268,12 +256,10 @@ impl Validator for TuplePositionalValidator { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.items_validators - .iter_mut() - .try_for_each(|v| v.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + 
self.items_validators.iter().try_for_each(CombinedValidator::complete)?; + match &self.extras_validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 56e4a8225..1a5b52dc6 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -20,7 +20,7 @@ use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; -#[derive(Debug, Clone)] +#[derive(Debug)] struct TypedDictField { name: String, lookup_key: LookupKey, @@ -31,7 +31,7 @@ struct TypedDictField { impl_py_gc_traverse!(TypedDictField { validator }); -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TypedDictValidator { fields: Vec, extra_behavior: ExtraBehavior, @@ -307,26 +307,20 @@ impl Validator for TypedDictValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.fields .iter() - .any(|f| f.validator.different_strict_behavior(definitions, ultra_strict)) + .any(|f| f.validator.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.fields - .iter_mut() - .try_for_each(|f| f.validator.complete(definitions))?; - match &mut self.extras_validator { - Some(v) => v.complete(definitions), + fn complete(&self) -> PyResult<()> { + self.fields.iter().try_for_each(|f| f.validator.complete())?; + match &self.extras_validator { + Some(v) => v.complete(), None => Ok(()), } } diff --git a/src/validators/union.rs b/src/validators/union.rs index 4d3b0bd78..79a21e78d 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -1,5 +1,6 @@ use std::fmt::Write; use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, 
PyString, PyTuple}; @@ -18,11 +19,11 @@ use super::custom_error::CustomError; use super::literal::LiteralLookup; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; -#[derive(Debug, Clone, Copy)] +#[derive(Debug)] enum UnionMode { Smart { - strict_required: bool, - ultra_strict_required: bool, + strict_required: AtomicBool, + ultra_strict_required: AtomicBool, }, LeftToRight, } @@ -31,8 +32,23 @@ impl UnionMode { // construct smart with some default values const fn default_smart() -> Self { Self::Smart { - strict_required: true, - ultra_strict_required: false, + strict_required: AtomicBool::new(true), + ultra_strict_required: AtomicBool::new(false), + } + } +} + +impl Clone for UnionMode { + fn clone(&self) -> Self { + match self { + Self::Smart { + strict_required, + ultra_strict_required, + } => Self::Smart { + strict_required: AtomicBool::new(strict_required.load(Ordering::SeqCst)), + ultra_strict_required: AtomicBool::new(ultra_strict_required.load(Ordering::SeqCst)), + }, + Self::LeftToRight => Self::LeftToRight, } } } @@ -49,7 +65,7 @@ impl FromStr for UnionMode { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct UnionValidator { mode: UnionMode, choices: Vec<(CombinedValidator, Option)>, @@ -216,44 +232,46 @@ impl Validator for UnionValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - match self.mode { + match &self.mode { UnionMode::Smart { strict_required, ultra_strict_required, - } => self.validate_smart(py, input, state, strict_required, ultra_strict_required), + } => self.validate_smart( + py, + input, + state, + strict_required.load(Ordering::SeqCst), + ultra_strict_required.load(Ordering::SeqCst), + ), UnionMode::LeftToRight => self.validate_left_to_right(py, input, state), } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn 
different_strict_behavior(&self, ultra_strict: bool) -> bool { self.choices .iter() - .any(|(v, _)| v.different_strict_behavior(definitions, ultra_strict)) + .any(|(v, _)| v.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.choices.iter_mut().try_for_each(|(v, _)| v.complete(definitions))?; + fn complete(&self) -> PyResult<()> { + self.choices.iter().try_for_each(|(v, _)| v.complete())?; if let UnionMode::Smart { - ref mut strict_required, - ref mut ultra_strict_required, - } = self.mode + strict_required, + ultra_strict_required, + } = &self.mode { - *strict_required = self - .choices - .iter() - .any(|(v, _)| v.different_strict_behavior(Some(definitions), false)); - *ultra_strict_required = self - .choices - .iter() - .any(|(v, _)| v.different_strict_behavior(Some(definitions), true)); + strict_required.store( + self.choices.iter().any(|(v, _)| v.different_strict_behavior(false)), + Ordering::SeqCst, + ); + ultra_strict_required.store( + self.choices.iter().any(|(v, _)| v.different_strict_behavior(true)), + Ordering::SeqCst, + ); } Ok(()) @@ -357,7 +375,7 @@ impl PyGcTraverse for Discriminator { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TaggedUnionValidator { discriminator: Discriminator, lookup: LiteralLookup, @@ -476,26 +494,19 @@ impl Validator for TaggedUnionValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { self.lookup .values .iter() - .any(|v| v.different_strict_behavior(definitions, ultra_strict)) + .any(|v| v.different_strict_behavior(ultra_strict)) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.lookup - .values - .iter_mut() - .try_for_each(|validator| validator.complete(definitions)) + fn 
complete(&self) -> PyResult<()> { + self.lookup.values.iter().try_for_each(CombinedValidator::complete) } } diff --git a/src/validators/url.rs b/src/validators/url.rs index 0afc76e59..4584ae652 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -92,11 +92,7 @@ impl Validator for UrlValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -104,7 +100,7 @@ impl Validator for UrlValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } @@ -232,11 +228,7 @@ impl Validator for MultiHostUrlValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -244,7 +236,7 @@ impl Validator for MultiHostUrlValidator { &self.name } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index ca924ce66..94d302438 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -122,11 +122,7 @@ impl Validator for UuidValidator { } } - fn different_strict_behavior( - &self, - _definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { !ultra_strict } @@ -134,7 +130,7 @@ impl Validator for UuidValidator { Self::EXPECTED_TYPE } - fn complete(&mut self, _definitions: &DefinitionsBuilder) -> PyResult<()> { + fn complete(&self) -> PyResult<()> { Ok(()) } } diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index 6cf5ce313..79ec8b87a 100644 --- a/src/validators/validation_state.rs +++ 
b/src/validators/validation_state.rs @@ -1,25 +1,16 @@ -use crate::{definitions::Definitions, recursion_guard::RecursionGuard}; +use crate::recursion_guard::RecursionGuard; -use super::{CombinedValidator, Extra}; +use super::Extra; pub struct ValidationState<'a> { pub recursion_guard: &'a mut RecursionGuard, - pub definitions: &'a Definitions, // deliberately make Extra readonly extra: Extra<'a>, } impl<'a> ValidationState<'a> { - pub fn new( - extra: Extra<'a>, - definitions: &'a Definitions, - recursion_guard: &'a mut RecursionGuard, - ) -> Self { - Self { - recursion_guard, - definitions, - extra, - } + pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionGuard) -> Self { + Self { recursion_guard, extra } } pub fn with_new_extra<'r, R: 'r>( @@ -31,7 +22,6 @@ impl<'a> ValidationState<'a> { // but lifetimes get in a tangle. Maybe someone brave wants to have a go at unpicking lifetimes. let mut new_state = ValidationState { recursion_guard: self.recursion_guard, - definitions: self.definitions, extra, }; f(&mut new_state) diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index d68590766..36b275dd1 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -66,7 +66,7 @@ enum OnError { Default, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct WithDefaultValidator { default: DefaultType, on_error: OnError, @@ -182,20 +182,16 @@ impl Validator for WithDefaultValidator { } } - fn different_strict_behavior( - &self, - definitions: Option<&DefinitionsBuilder>, - ultra_strict: bool, - ) -> bool { - self.validator.different_strict_behavior(definitions, ultra_strict) + fn different_strict_behavior(&self, ultra_strict: bool) -> bool { + self.validator.different_strict_behavior(ultra_strict) } fn get_name(&self) -> &str { &self.name } - fn complete(&mut self, definitions: &DefinitionsBuilder) -> PyResult<()> { - self.validator.complete(definitions) + fn complete(&self) -> PyResult<()> { + 
self.validator.complete() } } diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index b836eb7a1..2d676ac17 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -1,3 +1,4 @@ +import datetime import platform from dataclasses import dataclass from typing import List, Optional @@ -243,7 +244,7 @@ class Branch: def test_invalid_schema(): - with pytest.raises(SchemaError, match='Definitions error: attempted to use `Branch` before it was filled'): + with pytest.raises(SchemaError, match='Definitions error: definition `Branch` was never filled'): SchemaValidator( { 'type': 'list', @@ -895,3 +896,192 @@ class Model: 'url': f'https://errors.pydantic.dev/{pydantic_version}/v/dataclass_type', } ] + + +def test_cyclic_data() -> None: + cyclic_data = {} + cyclic_data['b'] = {'a': cyclic_data} + + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('a'), + [ + core_schema.typed_dict_schema( + { + 'b': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('b')) + ) + }, + ref='a', + ), + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('a')) + ) + }, + ref='b', + ), + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python(cyclic_data) + + assert exc_info.value.title == 'typed-dict' + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'recursion_loop', + 'loc': ('b', 'a'), + 'msg': 'Recursion error - cyclic reference detected', + 'input': cyclic_data, + } + ] + + +def test_cyclic_data_threeway() -> None: + cyclic_data = {} + cyclic_data['b'] = {'c': {'a': cyclic_data}} + + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('a'), + [ + core_schema.typed_dict_schema( + { + 'b': 
core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('b')) + ) + }, + ref='a', + ), + core_schema.typed_dict_schema( + { + 'c': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('c')) + ) + }, + ref='b', + ), + core_schema.typed_dict_schema( + { + 'a': core_schema.typed_dict_field( + core_schema.nullable_schema(core_schema.definition_reference_schema('a')) + ) + }, + ref='c', + ), + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python(cyclic_data) + + assert exc_info.value.title == 'typed-dict' + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'recursion_loop', + 'loc': ('b', 'c', 'a'), + 'msg': 'Recursion error - cyclic reference detected', + 'input': cyclic_data, + } + ] + + +def test_complex_recursive_type() -> None: + schema = core_schema.definitions_schema( + core_schema.definition_reference_schema('JsonType'), + [ + core_schema.nullable_schema( + core_schema.union_schema( + [ + core_schema.list_schema(core_schema.definition_reference_schema('JsonType')), + core_schema.dict_schema( + core_schema.str_schema(), core_schema.definition_reference_schema('JsonType') + ), + core_schema.str_schema(), + core_schema.int_schema(), + core_schema.float_schema(), + core_schema.bool_schema(), + ] + ), + ref='JsonType', + ) + ], + ) + + validator = SchemaValidator(schema) + + with pytest.raises(ValidationError) as exc_info: + validator.validate_python({'a': datetime.date(year=1992, month=12, day=11)}) + + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'list_type', + 'loc': ('list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]',), + 'msg': 'Input should be a valid list', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'list_type', + 'loc': ('dict[str,...]', 'a', 'list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]'), + 
'msg': 'Input should be a valid list', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'dict_type', + 'loc': ('dict[str,...]', 'a', 'dict[str,...]'), + 'msg': 'Input should be a valid dictionary', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'string_type', + 'loc': ('dict[str,...]', 'a', 'str'), + 'msg': 'Input should be a valid string', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'int_type', + 'loc': ('dict[str,...]', 'a', 'int'), + 'msg': 'Input should be a valid integer', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'float_type', + 'loc': ('dict[str,...]', 'a', 'float'), + 'msg': 'Input should be a valid number', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'bool_type', + 'loc': ('dict[str,...]', 'a', 'bool'), + 'msg': 'Input should be a valid boolean', + 'input': datetime.date(1992, 12, 11), + }, + { + 'type': 'string_type', + 'loc': ('str',), + 'msg': 'Input should be a valid string', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'int_type', + 'loc': ('int',), + 'msg': 'Input should be a valid integer', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'float_type', + 'loc': ('float',), + 'msg': 'Input should be a valid number', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + { + 'type': 'bool_type', + 'loc': ('bool',), + 'msg': 'Input should be a valid boolean', + 'input': {'a': datetime.date(1992, 12, 11)}, + }, + ] From f3c5714c47681264e1346c9ead3b7ad0f29e514e Mon Sep 17 00:00:00 2001 From: Sigurd Spieckermann <2206639+sisp@users.noreply.github.com> Date: Fri, 29 Sep 2023 16:08:13 +0200 Subject: [PATCH 056/550] Fix backwards compatibility of type-checking when using deprecated `FieldValidationInfo` (#995) --- python/pydantic_core/core_schema.py | 3 +++ tests/test_typing.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 2d7061ffd..347b93239 100644 --- 
a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3909,6 +3909,9 @@ def general_plain_validator_function(*args, **kwargs): 'FieldWrapValidatorFunction': WithInfoWrapValidatorFunction, } +if TYPE_CHECKING: + FieldValidationInfo = ValidationInfo + def __getattr__(attr_name: str) -> object: new_attr = _deprecated_import_lookup.get(attr_name) diff --git a/tests/test_typing.py b/tests/test_typing.py index 0d527c619..dcd2f267a 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -23,6 +23,10 @@ def foo(bar: str) -> None: ... +def validator_deprecated(value: Any, info: core_schema.FieldValidationInfo) -> None: + ... + + def validator(value: Any, info: core_schema.ValidationInfo) -> None: ... From 49126b0b981c6acc10dfb77f86aeaeb417a5979b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 29 Sep 2023 22:05:39 -0500 Subject: [PATCH 057/550] Add benchmark for nested/wide model using definitions (#997) --- benches/main.rs | 44 +++++++++++ tests/benchmarks/nested_schema.py | 93 +++++++++++++++++++++++ tests/benchmarks/test_nested_benchmark.py | 23 ++++++ 3 files changed, 160 insertions(+) create mode 100644 tests/benchmarks/nested_schema.py create mode 100644 tests/benchmarks/test_nested_benchmark.py diff --git a/benches/main.rs b/benches/main.rs index 9d46131d1..4b8a2b106 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -458,6 +458,50 @@ fn complete_model(bench: &mut Bencher) { }) } +#[bench] +fn nested_model_using_definitions(bench: &mut Bencher) { + Python::with_gil(|py| { + let sys_path = py.import("sys").unwrap().getattr("path").unwrap(); + sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap(); + + let complete_schema = py.import("nested_schema").unwrap(); + let mut schema = complete_schema.call_method0("schema_using_defs").unwrap(); + schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); + let validator = 
SchemaValidator::py_new(py, schema, None).unwrap(); + + let input = complete_schema.call_method0("input_data_valid").unwrap(); + let input = black_box(input); + + validator.validate_python(py, input, None, None, None, None).unwrap(); + + bench.iter(|| { + black_box(validator.validate_python(py, input, None, None, None, None).unwrap()); + }) + }) +} + +#[bench] +fn nested_model_inlined(bench: &mut Bencher) { + Python::with_gil(|py| { + let sys_path = py.import("sys").unwrap().getattr("path").unwrap(); + sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap(); + + let complete_schema = py.import("nested_schema").unwrap(); + let mut schema = complete_schema.call_method0("inlined_schema").unwrap(); + schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); + let validator = SchemaValidator::py_new(py, schema, None).unwrap(); + + let input = complete_schema.call_method0("input_data_valid").unwrap(); + let input = black_box(input); + + validator.validate_python(py, input, None, None, None, None).unwrap(); + + bench.iter(|| { + black_box(validator.validate_python(py, input, None, None, None, None).unwrap()); + }) + }) +} + #[bench] fn literal_ints_few_python(bench: &mut Bencher) { Python::with_gil(|py| { diff --git a/tests/benchmarks/nested_schema.py b/tests/benchmarks/nested_schema.py new file mode 100644 index 000000000..0d91d1217 --- /dev/null +++ b/tests/benchmarks/nested_schema.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from pydantic_core import core_schema as cs + +N = 5 # arbitrary number that takes ~0.05s per run + + +class MyModel: + # __slots__ is not required, but it avoids __pydantic_fields_set__ falling into __dict__ + __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' + + +def schema_using_defs() -> cs.CoreSchema: + definitions: list[cs.CoreSchema] = [ + {'type': 'int', 'ref': 'int'}, + { + 'type': 'model', 
+ 'cls': MyModel, + 'schema': { + 'type': 'model-fields', + 'fields': { + str(c): {'type': 'model-field', 'schema': {'type': 'definition-ref', 'schema_ref': 'int'}} + for c in range(N) + }, + }, + 'ref': f'model_{N}', + }, + ] + level = N + for level in reversed(range(N)): + definitions.append( + { + 'type': 'model', + 'cls': MyModel, + 'schema': { + 'type': 'model-fields', + 'fields': { + str(c): { + 'type': 'model-field', + 'schema': {'type': 'definition-ref', 'schema_ref': f'model_{level+1}'}, + } + for c in range(N) + }, + }, + 'ref': f'model_{level}', + } + ) + return { + 'type': 'definitions', + 'definitions': definitions, + 'schema': {'type': 'definition-ref', 'schema_ref': 'model_0'}, + } + + +def inlined_schema() -> cs.CoreSchema: + level = N + schema: cs.CoreSchema = { + 'type': 'model', + 'cls': MyModel, + 'schema': { + 'type': 'model-fields', + 'fields': {str(c): {'type': 'model-field', 'schema': {'type': 'int'}} for c in range(N)}, + }, + 'ref': f'model_{N}', + } + for level in reversed(range(N)): + schema = { + 'type': 'model', + 'cls': MyModel, + 'schema': { + 'type': 'model-fields', + 'fields': {str(c): {'type': 'model-field', 'schema': schema} for c in range(N)}, + }, + 'ref': f'model_{level}', + } + return schema + + +def input_data_valid(levels: int = N) -> Any: + data = {str(c): 1 for c in range(N)} + for _ in range(levels): + data = {str(c): data for c in range(N)} + return data + + +if __name__ == '__main__': + from pydantic_core import SchemaValidator + + SchemaValidator(schema_using_defs()).validate_python(input_data_valid()) + SchemaValidator(inlined_schema()).validate_python(input_data_valid()) diff --git a/tests/benchmarks/test_nested_benchmark.py b/tests/benchmarks/test_nested_benchmark.py new file mode 100644 index 000000000..6c8d50e83 --- /dev/null +++ b/tests/benchmarks/test_nested_benchmark.py @@ -0,0 +1,23 @@ +""" +Benchmarks for nested / recursive schemas using definitions. 
+""" + +from typing import Callable + +from pydantic_core import SchemaValidator + +from .nested_schema import inlined_schema, input_data_valid, schema_using_defs + + +def test_nested_schema_using_defs(benchmark: Callable[..., None]) -> None: + v = SchemaValidator(schema_using_defs()) + data = input_data_valid() + v.validate_python(data) + benchmark(v.validate_python, data) + + +def test_nested_schema_inlined(benchmark: Callable[..., None]) -> None: + v = SchemaValidator(inlined_schema()) + data = input_data_valid() + v.validate_python(data) + benchmark(v.validate_python, data) From b5fb6875b201c9d0764f2c6931ac5618b81329f1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 12:42:35 +0100 Subject: [PATCH 058/550] Bump griffe from 0.36.2 to 0.36.4 (#1000) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 48e914dbb..bece1c8ae 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 -griffe==0.36.2 +griffe==0.36.4 pyright==1.1.327 ruff==0.0.291 mypy==1.5.1 From a8ac5b1baffc28ba8f54ba04ad7c1d28f50527e2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 11:50:47 +0000 Subject: [PATCH 059/550] Bump pyright from 1.1.327 to 1.1.329 (#998) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index bece1c8ae..d1ae42ff5 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 
griffe==0.36.4 -pyright==1.1.327 +pyright==1.1.329 ruff==0.0.291 mypy==1.5.1 From 4622ed72d10b97dffba5d9fc46020f7510680a74 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Oct 2023 13:17:34 +0100 Subject: [PATCH 060/550] Bump regex from 1.9.5 to 1.9.6 (#1001) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a2db9901f..0898b6b1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -355,9 +355,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.5" +version = "1.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" +checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" dependencies = [ "aho-corasick", "memchr", @@ -367,9 +367,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" dependencies = [ "aho-corasick", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 7101a7019..9d18e224d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ [dependencies] pyo3 = { version = "0.19.2", features = ["generate-import-lib", "num-bigint"] } -regex = "1.9.5" +regex = "1.9.6" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.2" serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} From d93482e3b83f33cc134e8aa83992467c6dda8395 Mon Sep 17 00:00:00 2001 From: sydney-runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Mon, 2 Oct 2023 16:10:20 -0500 
Subject: [PATCH 061/550] Fix pydantic 7715 (#1002) Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> --- src/validators/dataclass.rs | 38 ++++--- src/validators/model_fields.rs | 36 +++++-- src/validators/typed_dict.rs | 38 +++++-- tests/validators/test_with_default.py | 150 ++++++++++++++++++++++++++ 4 files changed, 231 insertions(+), 31 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index dff7735a3..2706cadef 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -232,19 +232,31 @@ impl Validator for DataclassArgsValidator { } // found neither, check if there is a default value, otherwise error (None, None) => { - if let Some(value) = - field - .validator - .default_value(py, Some(field.name.as_str()), state)? - { - set_item!(field, value); - } else { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name, - )); + match field.validator.default_value(py, Some(field.name.as_str()), state) { + Ok(Some(value)) => { + // Default value exists, and passed validation if required + set_item!(field, value); + }, + Ok(None) => { + // This means there was no default value + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name + )); + }, + Err(ValError::Omit) => continue, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + // Note: this will always use the field name even if there is an alias + // However, we don't mind so much because this error can only happen if the + // default value fails validation, which is arguably a developer error. + // We could try to "fix" this in the future if desired. 
+ errors.push(err); + } + } + Err(err) => return Err(err), } } } diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 30d937cb8..774d3eef9 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -211,15 +211,33 @@ impl Validator for ModelFieldsValidator { Err(err) => return ControlFlow::Break(err.into_owned(py)), } continue; - } else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? { - control_flow!(model_dict.set_item(&field.name_py, value))?; - } else { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name - )); + } + + match field.validator.default_value(py, Some(field.name.as_str()), state) { + Ok(Some(value)) => { + // Default value exists, and passed validation if required + control_flow!(model_dict.set_item(&field.name_py, value))?; + }, + Ok(None) => { + // This means there was no default value + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name + )); + }, + Err(ValError::Omit) => continue, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + // Note: this will always use the field name even if there is an alias + // However, we don't mind so much because this error can only happen if the + // default value fails validation, which is arguably a developer error. + // We could try to "fix" this in the future if desired. 
+ errors.push(err); + } + } + Err(err) => return ControlFlow::Break(err), } } ControlFlow::Continue(()) diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 1a5b52dc6..118d992a2 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -212,15 +212,35 @@ impl Validator for TypedDictValidator { Err(err) => return ControlFlow::Break(err.into_owned(py)), } continue; - } else if let Some(value) = control_flow!(field.validator.default_value(py, Some(field.name.as_str()), state))? { - control_flow!(output_dict.set_item(&field.name_py, value))?; - } else if field.required { - errors.push(field.lookup_key.error( - ErrorTypeDefaults::Missing, - input, - self.loc_by_alias, - &field.name - )); + } + + match field.validator.default_value(py, Some(field.name.as_str()), state) { + Ok(Some(value)) => { + // Default value exists, and passed validation if required + control_flow!(output_dict.set_item(&field.name_py, value))?; + }, + Ok(None) => { + // This means there was no default value + if (field.required) { + errors.push(field.lookup_key.error( + ErrorTypeDefaults::Missing, + input, + self.loc_by_alias, + &field.name + )); + } + }, + Err(ValError::Omit) => continue, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + // Note: this will always use the field name even if there is an alias + // However, we don't mind so much because this error can only happen if the + // default value fails validation, which is arguably a developer error. + // We could try to "fix" this in the future if desired. 
+ errors.push(err); + } + } + Err(err) => return ControlFlow::Break(err), } } ControlFlow::Continue(()) diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 808e4807d..7ca0d9f54 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -654,3 +654,153 @@ def _validator(cls, v, info): gc.collect() assert ref() is None + + +validate_default_raises_examples = [ + ( + {}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {}}, + ], + ), + ( + {'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'x': None}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'x': None}}, + ], + ), + ( + {'x': None, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'y': None}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'y': None}}, + ], + ), + ( + {'y': None, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), 
+ ( + {'x': None, 'y': None}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'missing', 'loc': ('z',), 'msg': 'Field required', 'input': {'x': None, 'y': None}}, + ], + ), + ( + {'x': None, 'y': None, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'x': 1, 'y': None, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': 1}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': None}, + ], + ), + ( + {'x': None, 'y': 1, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': None}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': 1}, + ], + ), + ( + {'x': 1, 'y': 1, 'z': 'some str'}, + [ + {'type': 'assertion_error', 'loc': ('x',), 'msg': 'Assertion failed, ', 'input': 1}, + {'type': 'assertion_error', 'loc': ('y',), 'msg': 'Assertion failed, ', 'input': 1}, + ], + ), +] + + +@pytest.mark.parametrize( + 'core_schema_constructor,field_constructor', + [ + (core_schema.model_fields_schema, core_schema.model_field), + (core_schema.typed_dict_schema, core_schema.typed_dict_field), + ], +) +@pytest.mark.parametrize('input_value,expected', validate_default_raises_examples) +def test_validate_default_raises( + core_schema_constructor: Union[core_schema.ModelFieldsSchema, core_schema.TypedDictSchema], + field_constructor: Union[core_schema.model_field, core_schema.typed_dict_field], + input_value: dict, + expected: Any, +) -> None: + def _raise(ex: Exception) -> None: + raise ex() + + inner_schema = core_schema.no_info_after_validator_function( + lambda x: _raise(AssertionError), 
core_schema.nullable_schema(core_schema.int_schema()) + ) + + v = SchemaValidator( + core_schema_constructor( + { + 'x': field_constructor( + core_schema.with_default_schema(inner_schema, default=None, validate_default=True) + ), + 'y': field_constructor( + core_schema.with_default_schema(inner_schema, default=None, validate_default=True) + ), + 'z': field_constructor(core_schema.str_schema()), + } + ) + ) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(input_value) + assert exc_info.value.errors(include_url=False, include_context=False) == expected + + +@pytest.mark.parametrize('input_value,expected', validate_default_raises_examples) +def test_validate_default_raises_dataclass(input_value: dict, expected: Any) -> None: + def _raise(ex: Exception) -> None: + raise ex() + + inner_schema = core_schema.no_info_after_validator_function( + lambda x: _raise(AssertionError), core_schema.nullable_schema(core_schema.int_schema()) + ) + + x = core_schema.dataclass_field( + name='x', schema=core_schema.with_default_schema(inner_schema, default=None, validate_default=True) + ) + y = core_schema.dataclass_field( + name='y', schema=core_schema.with_default_schema(inner_schema, default=None, validate_default=True) + ) + z = core_schema.dataclass_field(name='z', schema=core_schema.str_schema()) + + v = SchemaValidator(core_schema.dataclass_args_schema('XYZ', [x, y, z])) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python(input_value) + + assert exc_info.value.errors(include_url=False, include_context=False) == expected From 493621705d5a19b6f4cbf3b7591d61f62f3029f5 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Mon, 9 Oct 2023 13:39:31 +0100 Subject: [PATCH 062/550] Fix new lint error from Rust 1.73 (#1010) --- src/validators/url.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/validators/url.rs b/src/validators/url.rs index 4584ae652..b5fe8bf4c 100644 --- 
a/src/validators/url.rs +++ b/src/validators/url.rs @@ -498,7 +498,7 @@ fn check_sub_defaults( if let Some(default_port) = default_port { lib_url .set_port(Some(default_port)) - .map_err(|_| map_parse_err(ParseError::EmptyHost))?; + .map_err(|()| map_parse_err(ParseError::EmptyHost))?; } } if let Some(ref default_path) = default_path { From 30444546a3a0ca54b9762a660f65f15e0bd52db3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 14:12:04 +0100 Subject: [PATCH 063/550] Bump ruff from 0.0.291 to 0.0.292 (#1009) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index d1ae42ff5..f16ee3288 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 griffe==0.36.4 pyright==1.1.329 -ruff==0.0.291 +ruff==0.0.292 mypy==1.5.1 From 172f6fcf99d2b7753e8f8eb61b8e02465491472e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 14:12:25 +0100 Subject: [PATCH 064/550] Bump pytest-timeout from 2.1.0 to 2.2.0 (#1008) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index ae8b9e50d..98cef770f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -10,7 +10,7 @@ pytest-examples==0.0.10 pytest-speed==0.3.5 pytest-mock==3.11.1 pytest-pretty==1.2.0 -pytest-timeout==2.1.0 +pytest-timeout==2.2.0 pytz==2023.3.post1 # numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. 
aarch64 musllinux numpy==1.25.2; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' From 8bbfb33bc8089f55715a5d1e6ea05a0ef43c4fc0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 9 Oct 2023 13:20:44 +0000 Subject: [PATCH 065/550] Bump pyright from 1.1.329 to 1.1.330.post0 (#1007) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index f16ee3288..25b74ce08 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 griffe==0.36.4 -pyright==1.1.329 +pyright==1.1.330.post0 ruff==0.0.292 mypy==1.5.1 From 8e66bd94758ab8b42d6645f29b7ce954e565ebf3 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Mon, 9 Oct 2023 15:19:01 +0100 Subject: [PATCH 066/550] Fix `regex_engine` being rejected by `validate_core_schema` (#1011) --- python/pydantic_core/core_schema.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 347b93239..adb8e5647 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -106,6 +106,7 @@ class CoreConfig(TypedDict, total=False): hide_input_in_errors: bool validation_error_cause: bool # default: False coerce_numbers_to_str: bool # default: False + regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex' IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' From b51105a4d459d93ae33ebc1f22219be71b03786f Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Mon, 9 Oct 2023 12:13:01 -0500 Subject: [PATCH 067/550] Add `SchemaSerializer.__reduce__` method to enable 
`pickle` serialization (#1006) Signed-off-by: Edward Oakes --- src/serializers/mod.rs | 26 +++++++++++++-- src/validators/mod.rs | 32 +++++++++++++----- tests/serializers/test_pickling.py | 50 ++++++++++++++++++++++++++++ tests/test.rs | 4 +-- tests/test_garbage_collection.py | 9 +++-- tests/validators/test_datetime.py | 12 ------- tests/validators/test_pickling.py | 53 ++++++++++++++++++++++++++++++ 7 files changed, 158 insertions(+), 28 deletions(-) create mode 100644 tests/serializers/test_pickling.py create mode 100644 tests/validators/test_pickling.py diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 72028346b..00d70162f 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -26,13 +26,17 @@ mod ob_type; mod shared; mod type_serializers; -#[pyclass(module = "pydantic_core._pydantic_core")] +#[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaSerializer { serializer: CombinedSerializer, definitions: Definitions, expected_json_size: AtomicUsize, config: SerializationConfig, + // References to the Python schema and config objects are saved to enable + // reconstructing the object for pickle support (see `__reduce__`). 
+ py_schema: Py, + py_config: Option>, } impl SchemaSerializer { @@ -71,15 +75,19 @@ impl SchemaSerializer { #[pymethods] impl SchemaSerializer { #[new] - pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult { + pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult { let mut definitions_builder = DefinitionsBuilder::new(); - let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?; Ok(Self { serializer, definitions: definitions_builder.finish()?, expected_json_size: AtomicUsize::new(1024), config: SerializationConfig::from_config(config)?, + py_schema: schema.into_py(py), + py_config: match config { + Some(c) if !c.is_empty() => Some(c.into_py(py)), + _ => None, + }, }) } @@ -174,6 +182,14 @@ impl SchemaSerializer { Ok(py_bytes.into()) } + pub fn __reduce__(slf: &PyCell) -> PyResult<(PyObject, (PyObject, PyObject))> { + // Enables support for `pickle` serialization. + let py = slf.py(); + let cls = slf.get_type().into(); + let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py)); + Ok((cls, init_args)) + } + pub fn __repr__(&self) -> String { format!( "SchemaSerializer(serializer={:#?}, definitions={:#?})", @@ -182,6 +198,10 @@ impl SchemaSerializer { } fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.py_schema)?; + if let Some(ref py_config) = self.py_config { + visit.call(py_config)?; + } self.serializer.py_gc_traverse(&visit)?; self.definitions.py_gc_traverse(&visit)?; Ok(()) diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 42aad2001..5db192f3a 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -97,12 +97,15 @@ impl PySome { } } -#[pyclass(module = "pydantic_core._pydantic_core")] +#[pyclass(module = "pydantic_core._pydantic_core", frozen)] #[derive(Debug)] pub struct SchemaValidator { validator: CombinedValidator, definitions: Definitions, - schema: PyObject, + // References 
to the Python schema and config objects are saved to enable + // reconstructing the object for cloudpickle support (see `__reduce__`). + py_schema: Py, + py_config: Option>, #[pyo3(get)] title: PyObject, hide_input_in_errors: bool, @@ -121,6 +124,11 @@ impl SchemaValidator { for val in definitions.values() { val.get().unwrap().complete()?; } + let py_schema = schema.into_py(py); + let py_config = match config { + Some(c) if !c.is_empty() => Some(c.into_py(py)), + _ => None, + }; let config_title = match config { Some(c) => c.get_item("title"), None => None, @@ -134,18 +142,20 @@ impl SchemaValidator { Ok(Self { validator, definitions, - schema: schema.into_py(py), + py_schema, + py_config, title, hide_input_in_errors, validation_error_cause, }) } - pub fn __reduce__(slf: &PyCell) -> PyResult { + pub fn __reduce__(slf: &PyCell) -> PyResult<(PyObject, (PyObject, PyObject))> { + // Enables support for `pickle` serialization. let py = slf.py(); - let args = (slf.try_borrow()?.schema.to_object(py),); - let cls = slf.getattr("__class__")?; - Ok((cls, args).into_py(py)) + let cls = slf.get_type().into(); + let init_args = (slf.get().py_schema.to_object(py), slf.get().py_config.to_object(py)); + Ok((cls, init_args)) } #[pyo3(signature = (input, *, strict=None, from_attributes=None, context=None, self_instance=None))] @@ -307,7 +317,10 @@ impl SchemaValidator { fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { self.validator.py_gc_traverse(&visit)?; - visit.call(&self.schema)?; + visit.call(&self.py_schema)?; + if let Some(ref py_config) = self.py_config { + visit.call(py_config)?; + } Ok(()) } } @@ -396,7 +409,8 @@ impl<'py> SelfValidator<'py> { Ok(SchemaValidator { validator, definitions, - schema: py.None(), + py_schema: py.None(), + py_config: None, title: "Self Schema".into_py(py), hide_input_in_errors: false, validation_error_cause: false, diff --git a/tests/serializers/test_pickling.py b/tests/serializers/test_pickling.py new file mode 
100644 index 000000000..2ca230313 --- /dev/null +++ b/tests/serializers/test_pickling.py @@ -0,0 +1,50 @@ +import json +import pickle +from datetime import timedelta + +import pytest + +from pydantic_core import core_schema +from pydantic_core._pydantic_core import SchemaSerializer + + +def repr_function(value, _info): + return repr(value) + + +def test_basic_schema_serializer(): + s = SchemaSerializer(core_schema.dict_schema()) + s = pickle.loads(pickle.dumps(s)) + assert s.to_python({'a': 1, b'b': 2, 33: 3}) == {'a': 1, b'b': 2, 33: 3} + assert s.to_python({'a': 1, b'b': 2, 33: 3, True: 4}, mode='json') == {'a': 1, 'b': 2, '33': 3, 'true': 4} + assert s.to_json({'a': 1, b'b': 2, 33: 3, True: 4}) == b'{"a":1,"b":2,"33":3,"true":4}' + + assert s.to_python({(1, 2): 3}) == {(1, 2): 3} + assert s.to_python({(1, 2): 3}, mode='json') == {'1,2': 3} + assert s.to_json({(1, 2): 3}) == b'{"1,2":3}' + + +@pytest.mark.parametrize( + 'value,expected_python,expected_json', + [(None, 'None', b'"None"'), (1, '1', b'"1"'), ([1, 2, 3], '[1, 2, 3]', b'"[1, 2, 3]"')], +) +def test_schema_serializer_capturing_function(value, expected_python, expected_json): + # Test a SchemaSerializer that captures a function. 
+ s = SchemaSerializer( + core_schema.any_schema( + serialization=core_schema.plain_serializer_function_ser_schema(repr_function, info_arg=True) + ) + ) + s = pickle.loads(pickle.dumps(s)) + assert s.to_python(value) == expected_python + assert s.to_json(value) == expected_json + assert s.to_python(value, mode='json') == json.loads(expected_json) + + +def test_schema_serializer_containing_config(): + s = SchemaSerializer(core_schema.timedelta_schema(), config={'ser_json_timedelta': 'float'}) + s = pickle.loads(pickle.dumps(s)) + + assert s.to_python(timedelta(seconds=4, microseconds=500_000)) == timedelta(seconds=4, microseconds=500_000) + assert s.to_python(timedelta(seconds=4, microseconds=500_000), mode='json') == 4.5 + assert s.to_json(timedelta(seconds=4, microseconds=500_000)) == b'4.5' diff --git a/tests/test.rs b/tests/test.rs index 526b30e5e..9b2fb99b5 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -46,7 +46,7 @@ mod tests { ] }"#; let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap(); - SchemaSerializer::py_new(schema, None).unwrap(); + SchemaSerializer::py_new(py, schema, None).unwrap(); }); } @@ -77,7 +77,7 @@ a = A() py.run(code, None, Some(locals)).unwrap(); let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap(); let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap(); - let serialized: Vec = SchemaSerializer::py_new(schema, None) + let serialized: Vec = SchemaSerializer::py_new(py, schema, None) .unwrap() .to_json(py, a, None, None, None, true, false, false, false, false, true, None) .unwrap() diff --git a/tests/test_garbage_collection.py b/tests/test_garbage_collection.py index d848c91ea..97107e61b 100644 --- a/tests/test_garbage_collection.py +++ b/tests/test_garbage_collection.py @@ -27,7 +27,9 @@ class BaseModel: __schema__: SchemaSerializer def __init_subclass__(cls) -> None: - cls.__schema__ = SchemaSerializer(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER)) + cls.__schema__ = 
SchemaSerializer( + core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER), config={'ser_json_timedelta': 'float'} + ) cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary() @@ -56,7 +58,10 @@ class BaseModel: __validator__: SchemaValidator def __init_subclass__(cls) -> None: - cls.__validator__ = SchemaValidator(core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER)) + cls.__validator__ = SchemaValidator( + core_schema.model_schema(cls, GC_TEST_SCHEMA_INNER), + config=core_schema.CoreConfig(extra_fields_behavior='allow'), + ) cache: 'WeakValueDictionary[int, Any]' = WeakValueDictionary() diff --git a/tests/validators/test_datetime.py b/tests/validators/test_datetime.py index df04d1631..67581119b 100644 --- a/tests/validators/test_datetime.py +++ b/tests/validators/test_datetime.py @@ -1,6 +1,5 @@ import copy import json -import pickle import platform import re from datetime import date, datetime, time, timedelta, timezone, tzinfo @@ -480,17 +479,6 @@ def test_tz_constraint_wrong(): validate_core_schema(core_schema.datetime_schema(tz_constraint='wrong')) -def test_tz_pickle() -> None: - """ - https://github.com/pydantic/pydantic-core/issues/589 - """ - v = SchemaValidator(core_schema.datetime_schema()) - original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15))) - validated = v.validate_python('2022-06-08T12:13:14-12:15') - assert validated == original - assert pickle.loads(pickle.dumps(validated)) == validated == original - - def test_tz_hash() -> None: v = SchemaValidator(core_schema.datetime_schema()) lookup: Dict[datetime, str] = {} diff --git a/tests/validators/test_pickling.py b/tests/validators/test_pickling.py new file mode 100644 index 000000000..2037ab8c9 --- /dev/null +++ b/tests/validators/test_pickling.py @@ -0,0 +1,53 @@ +import pickle +import re +from datetime import datetime, timedelta, timezone + +import pytest + +from pydantic_core import core_schema, validate_core_schema +from pydantic_core._pydantic_core 
import SchemaValidator, ValidationError + + +def test_basic_schema_validator(): + v = SchemaValidator( + validate_core_schema( + {'type': 'dict', 'strict': True, 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'int'}} + ) + ) + v = pickle.loads(pickle.dumps(v)) + assert v.validate_python({'1': 2, '3': 4}) == {1: 2, 3: 4} + assert v.validate_python({}) == {} + with pytest.raises(ValidationError, match=re.escape('[type=dict_type, input_value=[], input_type=list]')): + v.validate_python([]) + + +def test_schema_validator_containing_config(): + """ + Verify that the config object is not lost during (de)serialization. + """ + v = SchemaValidator( + core_schema.model_fields_schema({'f': core_schema.model_field(core_schema.str_schema())}), + config=core_schema.CoreConfig(extra_fields_behavior='allow'), + ) + v = pickle.loads(pickle.dumps(v)) + + m, model_extra, fields_set = v.validate_python({'f': 'x', 'extra_field': '123'}) + assert m == {'f': 'x'} + # If the config was lost during (de)serialization, the below checks would fail as + # the default behavior is to ignore extra fields. 
+ assert model_extra == {'extra_field': '123'} + assert fields_set == {'f', 'extra_field'} + + v.validate_assignment(m, 'f', 'y') + assert m == {'f': 'y'} + + +def test_schema_validator_tz_pickle() -> None: + """ + https://github.com/pydantic/pydantic-core/issues/589 + """ + v = SchemaValidator(core_schema.datetime_schema()) + original = datetime(2022, 6, 8, 12, 13, 14, tzinfo=timezone(timedelta(hours=-12, minutes=-15))) + validated = v.validate_python('2022-06-08T12:13:14-12:15') + assert validated == original + assert pickle.loads(pickle.dumps(validated)) == validated == original From f67e25c2c03a33bf60a18e66718df7e23fb0e79f Mon Sep 17 00:00:00 2001 From: sydney-runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 12 Oct 2023 07:56:25 -0500 Subject: [PATCH 068/550] Fix `definition-ref` bug with `Dict` keys (#1014) --- .../type_serializers/definitions.rs | 2 +- tests/serializers/test_definitions.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index cf92df244..b7bf63365 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -76,7 +76,7 @@ impl TypeSerializer for DefinitionRefSerializer { } fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { - self._invalid_as_json_key(key, extra, Self::EXPECTED_TYPE) + self.definition.get().unwrap().json_key(key, extra) } fn serde_serialize( diff --git a/tests/serializers/test_definitions.py b/tests/serializers/test_definitions.py index 2da4d353d..d45398097 100644 --- a/tests/serializers/test_definitions.py +++ b/tests/serializers/test_definitions.py @@ -113,3 +113,24 @@ def test_use_after(): ) ) assert v.to_python((1, 2)) == ('1', '2') + + +def test_defs_with_dict(): + s = SchemaSerializer( + core_schema.definitions_schema( + schema=core_schema.typed_dict_schema( + { + 'foo': 
core_schema.typed_dict_field( + core_schema.dict_schema( + keys_schema=core_schema.definition_reference_schema('key'), + values_schema=core_schema.definition_reference_schema('val'), + ) + ) + } + ), + definitions=[core_schema.str_schema(ref='key'), core_schema.str_schema(ref='val')], + ) + ) + + assert s.to_json({'foo': {'key': 'val'}}) == b'{"foo":{"key":"val"}}' + assert s.to_python({'foo': {'key': 'val'}}) == {'foo': {'key': 'val'}} From 372904c53c94aa410e4778755dcb4b2567a3a4d9 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Sat, 14 Oct 2023 21:13:41 +0100 Subject: [PATCH 069/550] Update PyO3 to 0.20 (#1003) --- Cargo.lock | 50 ++++++++----------- Cargo.toml | 6 +-- src/errors/types.rs | 2 +- src/errors/validation_exception.rs | 8 +-- src/errors/value_exception.rs | 4 +- src/input/input_python.rs | 14 +++--- src/input/return_enums.rs | 2 +- src/lookup_key.rs | 6 +-- src/serializers/extra.rs | 2 +- src/serializers/filter.rs | 10 ++-- src/serializers/type_serializers/dict.rs | 4 +- src/serializers/type_serializers/model.rs | 2 +- .../type_serializers/typed_dict.rs | 2 +- src/tools.rs | 4 +- src/validators/arguments.rs | 6 +-- src/validators/bytes.rs | 4 +- src/validators/call.rs | 2 +- src/validators/dataclass.rs | 6 +-- src/validators/datetime.rs | 2 +- src/validators/decimal.rs | 2 +- src/validators/dict.rs | 4 +- src/validators/float.rs | 10 ++-- src/validators/int.rs | 10 ++-- src/validators/list.rs | 2 +- src/validators/literal.rs | 2 +- src/validators/mod.rs | 2 +- src/validators/model_fields.rs | 4 +- src/validators/set.rs | 2 +- src/validators/string.rs | 8 +-- src/validators/timedelta.rs | 2 +- src/validators/tuple.rs | 2 +- src/validators/typed_dict.rs | 4 +- src/validators/union.rs | 4 +- tests/test.rs | 4 +- 34 files changed, 94 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0898b6b1f..3c2af0c54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,7 +62,7 @@ dependencies = [ 
"once_cell", "proc-macro2", "quote", - "syn 2.0.28", + "syn", ] [[package]] @@ -125,9 +125,9 @@ dependencies = [ [[package]] name = "indoc" -version = "1.0.9" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" [[package]] name = "itoa" @@ -266,9 +266,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" +checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" dependencies = [ "cfg-if", "indoc", @@ -284,9 +284,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076c73d0bc438f7a4ef6fdd0c3bb4732149136abd952b110ac93e4edb13a6ba5" +checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" dependencies = [ "once_cell", "python3-dll-a", @@ -295,9 +295,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53cee42e77ebe256066ba8aa77eff722b3bb91f3419177cf4cd0f304d3284d9" +checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" dependencies = [ "libc", "pyo3-build-config", @@ -305,25 +305,26 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfeb4c99597e136528c6dd7d5e3de5434d1ceaf487436a3f03b2d56b6fc9efd1" +checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 1.0.109", + "syn", ] [[package]] name = 
"pyo3-macros-backend" -version = "0.19.2" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "947dc12175c254889edc0c02e399476c2f652b4b9ebd123aa655c224de259536" +checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" dependencies = [ + "heck", "proc-macro2", "quote", - "syn 1.0.109", + "syn", ] [[package]] @@ -417,7 +418,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.28", + "syn", ] [[package]] @@ -467,18 +468,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.28", -] - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", + "syn", ] [[package]] @@ -536,9 +526,9 @@ dependencies = [ [[package]] name = "unindent" -version = "0.1.11" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" [[package]] name = "url" diff --git a/Cargo.toml b/Cargo.toml index 9d18e224d..4a6541781 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ include = [ ] [dependencies] -pyo3 = { version = "0.19.2", features = ["generate-import-lib", "num-bigint"] } +pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] } regex = "1.9.6" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.2" @@ -62,9 +62,9 @@ debug = true strip = false [dev-dependencies] -pyo3 = { version= "0.19.2", features = ["auto-initialize"] } +pyo3 = { version = "0.20.0", features = ["auto-initialize"] } [build-dependencies] version_check = "0.9.4" # used where logic has to be version/distribution specific, e.g. 
pypy -pyo3-build-config = { version = "0.19.2" } +pyo3-build-config = { version = "0.20.0" } diff --git a/src/errors/types.rs b/src/errors/types.rs index d537158ba..e31307e2b 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -50,7 +50,7 @@ fn field_from_context<'py, T: FromPyObject<'py>>( ) -> PyResult { context .ok_or_else(|| py_error_type!(PyTypeError; "{}: '{}' required in context", enum_name, field_name))? - .get_item(field_name) + .get_item(field_name)? .ok_or_else(|| py_error_type!(PyTypeError; "{}: '{}' required in context", enum_name, field_name))? .extract::() .map_err(|_| py_error_type!(PyTypeError; "{}: '{}' context value must be a {}", enum_name, field_name, type_name_fn())) diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 09154d8ac..cb3802d56 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -4,10 +4,10 @@ use std::str::from_utf8; use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError}; use pyo3::ffi; +use pyo3::intern; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString}; -use pyo3::{intern, AsPyPointer}; use serde::ser::{Error, SerializeMap, SerializeSeq}; use serde::{Serialize, Serializer}; @@ -445,7 +445,7 @@ impl TryFrom<&PyAny> for PyLineError { let py = value.py(); let type_raw = dict - .get_item(intern!(py, "type")) + .get_item(intern!(py, "type"))? .ok_or_else(|| PyKeyError::new_err("type"))?; let error_type = if let Ok(type_str) = type_raw.downcast::() { @@ -459,9 +459,9 @@ impl TryFrom<&PyAny> for PyLineError { )); }; - let location = Location::try_from(dict.get_item("loc"))?; + let location = Location::try_from(dict.get_item("loc")?)?; - let input_value = match dict.get_item("input") { + let input_value = match dict.get_item("input")? 
{ Some(i) => i.into_py(py), None => py.None(), }; diff --git a/src/errors/value_exception.rs b/src/errors/value_exception.rs index f7d877b30..7bc7e5227 100644 --- a/src/errors/value_exception.rs +++ b/src/errors/value_exception.rs @@ -72,7 +72,7 @@ impl PydanticCustomError { } } - #[getter(type)] + #[getter(r#type)] pub fn error_type(&self) -> String { self.error_type.clone() } @@ -147,7 +147,7 @@ impl PydanticKnownError { Ok(Self { error_type }) } - #[getter(type)] + #[getter(r#type)] pub fn error_type(&self) -> String { self.error_type.to_string() } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index cf84c5517..cd8ecf267 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -8,7 +8,7 @@ use pyo3::types::{ }; #[cfg(not(PyPy))] use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; -use pyo3::{intern, AsPyPointer, PyTypeInfo}; +use pyo3::{intern, PyTypeInfo}; use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; @@ -32,7 +32,7 @@ macro_rules! extract_dict_keys { ($py:expr, $obj:ident) => { $obj.downcast::() .ok() - .map(|v| PyIterator::from_object($py, v).unwrap()) + .map(|v| PyIterator::from_object(v).unwrap()) }; } @@ -40,7 +40,7 @@ macro_rules! extract_dict_keys { macro_rules! extract_dict_keys { ($py:expr, $obj:ident) => { if is_dict_keys_type($obj) { - Some(PyIterator::from_object($py, $obj).unwrap()) + Some(PyIterator::from_object($obj).unwrap()) } else { None } @@ -52,7 +52,7 @@ macro_rules! extract_dict_values { ($py:expr, $obj:ident) => { $obj.downcast::() .ok() - .map(|v| PyIterator::from_object($py, v).unwrap()) + .map(|v| PyIterator::from_object(v).unwrap()) }; } @@ -60,7 +60,7 @@ macro_rules! extract_dict_values { macro_rules! 
extract_dict_values { ($py:expr, $obj:ident) => { if is_dict_values_type($obj) { - Some(PyIterator::from_object($py, $obj).unwrap()) + Some(PyIterator::from_object($obj).unwrap()) } else { None } @@ -72,7 +72,7 @@ macro_rules! extract_dict_items { ($py:expr, $obj:ident) => { $obj.downcast::() .ok() - .map(|v| PyIterator::from_object($py, v).unwrap()) + .map(|v| PyIterator::from_object(v).unwrap()) }; } @@ -80,7 +80,7 @@ macro_rules! extract_dict_items { macro_rules! extract_dict_items { ($py:expr, $obj:ident) => { if is_dict_items_type($obj) { - Some(PyIterator::from_object($py, $obj).unwrap()) + Some(PyIterator::from_object($obj).unwrap()) } else { None } diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index c492f40f0..daa9f39fe 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -13,7 +13,7 @@ use pyo3::types::{ PyByteArray, PyBytes, PyDict, PyFloat, PyFrozenSet, PyIterator, PyList, PyMapping, PySequence, PySet, PyString, PyTuple, }; -use pyo3::{ffi, intern, AsPyPointer, PyNativeType}; +use pyo3::{ffi, intern, PyNativeType}; #[cfg(not(PyPy))] use pyo3::types::PyFunction; diff --git a/src/lookup_key.rs b/src/lookup_key.rs index 36190c069..bb7d7e3d7 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -111,7 +111,7 @@ impl LookupKey { dict: &'data PyDict, ) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> { match self { - Self::Simple { py_key, path, .. } => match dict.get_item(py_key) { + Self::Simple { py_key, path, .. } => match dict.get_item(py_key)? { Some(value) => Ok(Some((path, value))), None => Ok(None), }, @@ -121,9 +121,9 @@ impl LookupKey { py_key2, path2, .. - } => match dict.get_item(py_key1) { + } => match dict.get_item(py_key1)? { Some(value) => Ok(Some((path1, value))), - None => match dict.get_item(py_key2) { + None => match dict.get_item(py_key2)? 
{ Some(value) => Ok(Some((path2, value))), None => Ok(None), }, diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 65c5a1ba9..7a9b84704 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -2,8 +2,8 @@ use std::cell::RefCell; use std::fmt; use pyo3::exceptions::PyValueError; +use pyo3::intern; use pyo3::prelude::*; -use pyo3::{intern, AsPyPointer}; use serde::ser::Error; diff --git a/src/serializers/filter.rs b/src/serializers/filter.rs index 89f923552..0efec56e8 100644 --- a/src/serializers/filter.rs +++ b/src/serializers/filter.rs @@ -60,8 +60,8 @@ impl SchemaFilter { let py = schema.py(); match schema.get_as::<&PyDict>(intern!(py, "serialization"))? { Some(ser) => { - let include = Self::build_set_ints(ser.get_item(intern!(py, "include")))?; - let exclude = Self::build_set_ints(ser.get_item(intern!(py, "exclude")))?; + let include = Self::build_set_ints(ser.get_item(intern!(py, "include"))?)?; + let exclude = Self::build_set_ints(ser.get_item(intern!(py, "exclude"))?)?; Ok(Self { include, exclude }) } None => Ok(SchemaFilter::default()), @@ -325,8 +325,8 @@ fn is_ellipsis_like(v: &PyAny) -> bool { /// lookup the dict, for the key and "__all__" key, and merge them following the same rules as pydantic V1 fn merge_all_value(dict: &PyDict, py_key: impl ToPyObject + Copy) -> PyResult> { - let op_item_value = dict.get_item(py_key); - let op_all_value = dict.get_item(intern!(dict.py(), "__all__")); + let op_item_value = dict.get_item(py_key)?; + let op_all_value = dict.get_item(intern!(dict.py(), "__all__"))?; match (op_item_value, op_all_value) { (Some(item_value), Some(all_value)) => { @@ -365,7 +365,7 @@ fn merge_dicts<'py>(item_dict: &'py PyDict, all_value: &'py PyAny) -> PyResult<& let item_dict = item_dict.copy()?; if let Ok(all_dict) = all_value.downcast::() { for (all_key, all_value) in all_dict { - if let Some(item_value) = item_dict.get_item(all_key) { + if let Some(item_value) = item_dict.get_item(all_key)? 
{ if is_ellipsis_like(item_value) { continue; } diff --git a/src/serializers/type_serializers/dict.rs b/src/serializers/type_serializers/dict.rs index bb2a18633..89851752e 100644 --- a/src/serializers/type_serializers/dict.rs +++ b/src/serializers/type_serializers/dict.rs @@ -43,8 +43,8 @@ impl BuildSerializer for DictSerializer { }; let filter = match schema.get_as::<&PyDict>(intern!(py, "serialization"))? { Some(ser) => { - let include = ser.get_item(intern!(py, "include")); - let exclude = ser.get_item(intern!(py, "exclude")); + let include = ser.get_item(intern!(py, "include"))?; + let exclude = ser.get_item(intern!(py, "exclude"))?; SchemaFilter::from_set_hash(include, exclude)? } None => SchemaFilter::default(), diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 8a2eeb4e1..0d2d1d346 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -39,7 +39,7 @@ impl BuildSerializer for ModelFieldsBuilder { let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?; let mut fields: AHashMap = AHashMap::with_capacity(fields_dict.len()); - let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) { + let extra_serializer = match (schema.get_item(intern!(py, "extras_schema"))?, &fields_mode) { (Some(v), FieldsMode::ModelExtra) => Some(CombinedSerializer::build(v.extract()?, config, definitions)?), (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, diff --git a/src/serializers/type_serializers/typed_dict.rs b/src/serializers/type_serializers/typed_dict.rs index 5967738ae..fbef3486a 100644 --- a/src/serializers/type_serializers/typed_dict.rs +++ b/src/serializers/type_serializers/typed_dict.rs @@ -35,7 +35,7 @@ impl BuildSerializer for TypedDictBuilder { let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?; let mut fields: AHashMap = 
AHashMap::with_capacity(fields_dict.len()); - let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) { + let extra_serializer = match (schema.get_item(intern!(py, "extras_schema"))?, &fields_mode) { (Some(v), FieldsMode::TypedDictAllow) => { Some(CombinedSerializer::build(v.extract()?, config, definitions)?) } diff --git a/src/tools.rs b/src/tools.rs index 3c75decf1..af58131f5 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -20,7 +20,7 @@ impl<'py> SchemaDict<'py> for PyDict { where T: FromPyObject<'py>, { - match self.get_item(key) { + match self.get_item(key)? { Some(t) => Ok(Some(::extract(t)?)), None => Ok(None), } @@ -30,7 +30,7 @@ impl<'py> SchemaDict<'py> for PyDict { where T: FromPyObject<'py>, { - match self.get_item(key) { + match self.get_item(key)? { Some(t) => ::extract(t), None => py_err!(PyKeyError; "{}", key), } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index aa0870043..7f406ba16 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -66,7 +66,7 @@ impl BuildValidator for ArgumentsValidator { let mut kw_lookup_key = None; let mut kwarg_key = None; if mode == "keyword_only" || mode == "positional_or_keyword" { - kw_lookup_key = match arg.get_item(intern!(py, "alias")) { + kw_lookup_key = match arg.get_item(intern!(py, "alias"))? { Some(alias) => { let alt_alias = if populate_by_name { Some(name.as_str()) } else { None }; Some(LookupKey::from_py(py, alias, alt_alias)?) @@ -110,11 +110,11 @@ impl BuildValidator for ArgumentsValidator { Ok(Self { parameters, positional_params_count, - var_args_validator: match schema.get_item(intern!(py, "var_args_schema")) { + var_args_validator: match schema.get_item(intern!(py, "var_args_schema"))? 
{ Some(v) => Some(Box::new(build_validator(v, config, definitions)?)), None => None, }, - var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema")) { + var_kwargs_validator: match schema.get_item(intern!(py, "var_kwargs_schema"))? { Some(v) => Some(Box::new(build_validator(v, config, definitions)?)), None => None, }, diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 0f662af1c..fb90187ac 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -24,8 +24,8 @@ impl BuildValidator for BytesValidator { _definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let use_constrained = schema.get_item(intern!(py, "max_length")).is_some() - || schema.get_item(intern!(py, "min_length")).is_some(); + let use_constrained = schema.get_item(intern!(py, "max_length"))?.is_some() + || schema.get_item(intern!(py, "min_length"))?.is_some(); if use_constrained { BytesConstrainedValidator::build(schema, config) } else { diff --git a/src/validators/call.rs b/src/validators/call.rs index 3f6eb4d35..eca1f0206 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -32,7 +32,7 @@ impl BuildValidator for CallValidator { let arguments_schema: &PyAny = schema.get_as_req(intern!(py, "arguments_schema"))?; let arguments_validator = Box::new(build_validator(arguments_schema, config, definitions)?); - let return_schema = schema.get_item(intern!(py, "return_schema")); + let return_schema = schema.get_item(intern!(py, "return_schema"))?; let return_validator = match return_schema { Some(return_schema) => Some(Box::new(build_validator(return_schema, config, definitions)?)), None => None, diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 2706cadef..b18faea2c 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -56,7 +56,7 @@ impl BuildValidator for DataclassArgsValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, 
ExtraBehavior::Ignore)?; - let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { + let extras_validator = match (schema.get_item(intern!(py, "extras_schema"))?, &extra_behavior) { (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)), (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, @@ -73,7 +73,7 @@ impl BuildValidator for DataclassArgsValidator { let py_name: &PyString = field.get_as_req(intern!(py, "name"))?; let name: String = py_name.extract()?; - let lookup_key = match field.get_item(intern!(py, "validation_alias")) { + let lookup_key = match field.get_item(intern!(py, "validation_alias"))? { Some(alias) => { let alt_alias = if populate_by_name { Some(name.as_str()) } else { None }; LookupKey::from_py(py, alias, alt_alias)? @@ -584,7 +584,7 @@ impl Validator for DataclassValidator { if self.slots { let value = dc_dict - .get_item(field_name) + .get_item(field_name)? .ok_or_else(|| PyKeyError::new_err(field_name.to_string()))?; force_setattr(py, obj, field_name, value)?; } else { diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index baf4ca467..5f1fc8bef 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -256,7 +256,7 @@ impl TZConstraint { pub(super) fn from_py(schema: &PyDict) -> PyResult> { let py = schema.py(); - let tz_constraint = match schema.get_item(intern!(py, "tz_constraint")) { + let tz_constraint = match schema.get_item(intern!(py, "tz_constraint"))? 
{ Some(c) => c, None => return Ok(None), }; diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index 211befe07..be19d1eda 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -1,7 +1,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::intern; use pyo3::sync::GILOnceCell; use pyo3::types::{IntoPyDict, PyDict, PyTuple, PyType}; -use pyo3::{intern, AsPyPointer}; use pyo3::{prelude::*, PyTypeInfo}; use crate::build_tools::{is_strict, schema_or_config_same}; diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 250145290..5026afba3 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -35,11 +35,11 @@ impl BuildValidator for DictValidator { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let key_validator = match schema.get_item(intern!(py, "keys_schema")) { + let key_validator = match schema.get_item(intern!(py, "keys_schema"))? { Some(schema) => Box::new(build_validator(schema, config, definitions)?), None => Box::new(AnyValidator::build(schema, config, definitions)?), }; - let value_validator = match schema.get_item(intern!(py, "values_schema")) { + let value_validator = match schema.get_item(intern!(py, "values_schema"))? 
{ Some(d) => Box::new(build_validator(d, config, definitions)?), None => Box::new(AnyValidator::build(schema, config, definitions)?), }; diff --git a/src/validators/float.rs b/src/validators/float.rs index 2e9434d9f..1d62d2006 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -19,11 +19,11 @@ impl BuildValidator for FloatBuilder { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let use_constrained = schema.get_item(intern!(py, "multiple_of")).is_some() - || schema.get_item(intern!(py, "le")).is_some() - || schema.get_item(intern!(py, "lt")).is_some() - || schema.get_item(intern!(py, "ge")).is_some() - || schema.get_item(intern!(py, "gt")).is_some(); + let use_constrained = schema.get_item(intern!(py, "multiple_of"))?.is_some() + || schema.get_item(intern!(py, "le"))?.is_some() + || schema.get_item(intern!(py, "lt"))?.is_some() + || schema.get_item(intern!(py, "ge"))?.is_some() + || schema.get_item(intern!(py, "gt"))?.is_some(); if use_constrained { ConstrainedFloatValidator::build(schema, config, definitions) } else { diff --git a/src/validators/int.rs b/src/validators/int.rs index 0903e1998..1a807d1ef 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -25,11 +25,11 @@ impl BuildValidator for IntValidator { _definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let use_constrained = schema.get_item(intern!(py, "multiple_of")).is_some() - || schema.get_item(intern!(py, "le")).is_some() - || schema.get_item(intern!(py, "lt")).is_some() - || schema.get_item(intern!(py, "ge")).is_some() - || schema.get_item(intern!(py, "gt")).is_some(); + let use_constrained = schema.get_item(intern!(py, "multiple_of"))?.is_some() + || schema.get_item(intern!(py, "le"))?.is_some() + || schema.get_item(intern!(py, "lt"))?.is_some() + || schema.get_item(intern!(py, "ge"))?.is_some() + || schema.get_item(intern!(py, "gt"))?.is_some(); if use_constrained { ConstrainedIntValidator::build(schema, config) 
} else { diff --git a/src/validators/list.rs b/src/validators/list.rs index 8e931657f..c0af5fcb5 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -23,7 +23,7 @@ pub fn get_items_schema( config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult> { - match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { + match schema.get_item(pyo3::intern!(schema.py(), "items_schema"))? { Some(d) => { let validator = build_validator(d, config, definitions)?; match validator { diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 25cb94bd9..19bbe91f5 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -126,7 +126,7 @@ impl LiteralLookup { } // must be an enum or bytes if let Some(expected_py) = &self.expected_py { - if let Some(v) = expected_py.as_ref(py).get_item(input) { + if let Some(v) = expected_py.as_ref(py).get_item(input)? { let id: usize = v.extract().unwrap(); return Ok(Some((input, &self.values[id]))); } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 5db192f3a..2a4cbf165 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -130,7 +130,7 @@ impl SchemaValidator { _ => None, }; let config_title = match config { - Some(c) => c.get_item("title"), + Some(c) => c.get_item("title")?, None => None, }; let title = match config_title { diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 774d3eef9..b79145c97 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -58,7 +58,7 @@ impl BuildValidator for ModelFieldsValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; - let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { + let extras_validator = match (schema.get_item(intern!(py, "extras_schema"))?, &extra_behavior) { (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, 
definitions)?)), (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, @@ -81,7 +81,7 @@ impl BuildValidator for ModelFieldsValidator { Err(err) => return py_schema_err!("Field \"{}\":\n {}", field_name, err), }; - let lookup_key = match field_info.get_item(intern!(py, "validation_alias")) { + let lookup_key = match field_info.get_item(intern!(py, "validation_alias"))? { Some(alias) => { let alt_alias = if populate_by_name { Some(field_name) } else { None }; LookupKey::from_py(py, alias, alt_alias)? diff --git a/src/validators/set.rs b/src/validators/set.rs index 626572139..9270ea204 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -25,7 +25,7 @@ macro_rules! set_build { definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = match schema.get_item(pyo3::intern!(schema.py(), "items_schema")) { + let item_validator = match schema.get_item(pyo3::intern!(schema.py(), "items_schema"))? 
{ Some(d) => Box::new(crate::validators::build_validator(d, config, definitions)?), None => Box::new(crate::validators::any::AnyValidator::build( schema, diff --git a/src/validators/string.rs b/src/validators/string.rs index 4eab4602a..0be51ece8 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -189,10 +189,10 @@ impl StrConstrainedValidator { let to_upper: bool = schema_or_config(schema, config, intern!(py, "to_upper"), intern!(py, "str_to_upper"))?.unwrap_or(false); - let coerce_numbers_to_str = config - .and_then(|c| c.get_item("coerce_numbers_to_str")) - .and_then(|v| v.is_true().ok()) - .unwrap_or(false); + let coerce_numbers_to_str = match config { + Some(c) => c.get_item("coerce_numbers_to_str")?.map_or(Ok(false), PyAny::is_true)?, + None => false, + }; Ok(Self { strict: is_strict(schema, config)?, diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index 21340f2f0..a58d98f7a 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -25,7 +25,7 @@ struct TimedeltaConstraints { } fn get_constraint(schema: &PyDict, key: &str) -> PyResult> { - match schema.get_item(key) { + match schema.get_item(key)? { Some(value) => { let either_timedelta = EitherTimedelta::try_from(value)?; Ok(Some(either_timedelta.to_duration()?)) diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index cfa239e3b..1b7cf9b00 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -113,7 +113,7 @@ impl BuildValidator for TuplePositionalValidator { Ok(Self { strict: is_strict(schema, config)?, items_validators: validators, - extras_validator: match schema.get_item(intern!(py, "extras_schema")) { + extras_validator: match schema.get_item(intern!(py, "extras_schema"))? 
{ Some(v) => Some(Box::new(build_validator(v, config, definitions)?)), None => None, }, diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 118d992a2..dab492da7 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -61,7 +61,7 @@ impl BuildValidator for TypedDictValidator { let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; - let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) { + let extras_validator = match (schema.get_item(intern!(py, "extras_schema"))?, &extra_behavior) { (Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)), (Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"), (_, _) => None, @@ -109,7 +109,7 @@ impl BuildValidator for TypedDictValidator { } } - let lookup_key = match field_info.get_item(intern!(py, "validation_alias")) { + let lookup_key = match field_info.get_item(intern!(py, "validation_alias"))? { Some(alias) => { let alt_alias = if populate_by_name { Some(field_name) } else { None }; LookupKey::from_py(py, alias, alt_alias)? diff --git a/src/validators/union.rs b/src/validators/union.rs index 79a21e78d..a8bd29d7d 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -518,7 +518,7 @@ impl TaggedUnionValidator { ) -> ValResult<'data, &'data PyString> { let dict = input.strict_dict()?; let either_tag = match dict { - GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "type")) { + GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "type"))? 
{ Some(t) => t.strict_str()?, None => return Err(self.tag_not_found(input)), }, @@ -529,7 +529,7 @@ impl TaggedUnionValidator { // custom logic to distinguish between different function and tuple schemas if tag == "function" || tag == "tuple" { let mode = match dict { - GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "mode")) { + GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "mode"))? { Some(m) => Some(m.strict_str()?), None => None, }, diff --git a/tests/test.rs b/tests/test.rs index 9b2fb99b5..348520435 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -75,8 +75,8 @@ a = A() "#; let locals = PyDict::new(py); py.run(code, None, Some(locals)).unwrap(); - let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap(); - let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap(); + let a: &PyAny = locals.get_item("a").unwrap().unwrap().extract().unwrap(); + let schema: &PyDict = locals.get_item("schema").unwrap().unwrap().extract().unwrap(); let serialized: Vec = SchemaSerializer::py_new(py, schema, None) .unwrap() .to_json(py, a, None, None, None, true, false, false, false, false, true, None) From 005c8a752d6335d78a1f8cda5179cbbd3f44992e Mon Sep 17 00:00:00 2001 From: Eric Jolibois Date: Sun, 15 Oct 2023 09:56:38 +0200 Subject: [PATCH 070/550] feat: add `ser_json_bytes` mode `'hex'` (#1016) --- python/pydantic_core/core_schema.py | 2 +- src/serializers/config.rs | 13 ++++++++++++- tests/serializers/test_bytes.py | 7 +++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index adb8e5647..40d1eec77 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -101,7 +101,7 @@ class CoreConfig(TypedDict, total=False): allow_inf_nan: bool # default: True # the config options are used to customise serialization to JSON ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601' - 
ser_json_bytes: Literal['utf8', 'base64'] # default: 'utf8' + ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8' # used to hide input data from ValidationError repr hide_input_in_errors: bool validation_error_cause: bool # default: False diff --git a/src/serializers/config.rs b/src/serializers/config.rs index 4f3611129..fb65623b7 100644 --- a/src/serializers/config.rs +++ b/src/serializers/config.rs @@ -129,6 +129,7 @@ pub(crate) enum BytesMode { #[default] Utf8, Base64, + Hex, } impl FromStr for BytesMode { @@ -138,7 +139,11 @@ impl FromStr for BytesMode { match s { "utf8" => Ok(Self::Utf8), "base64" => Ok(Self::Base64), - s => py_schema_err!("Invalid bytes serialization mode: `{}`, expected `utf8` or `base64`", s), + "hex" => Ok(Self::Hex), + s => py_schema_err!( + "Invalid bytes serialization mode: `{}`, expected `utf8`, `base64` or `hex`", + s + ), } } } @@ -158,6 +163,9 @@ impl BytesMode { .map_err(|err| utf8_py_error(py, err, bytes)) .map(Cow::Borrowed), Self::Base64 => Ok(Cow::Owned(base64::engine::general_purpose::URL_SAFE.encode(bytes))), + Self::Hex => Ok(Cow::Owned( + bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}")), + )), } } @@ -168,6 +176,9 @@ impl BytesMode { Err(e) => Err(Error::custom(e.to_string())), }, Self::Base64 => serializer.serialize_str(&base64::engine::general_purpose::URL_SAFE.encode(bytes)), + Self::Hex => { + serializer.serialize_str(&bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}"))) + } } } } diff --git a/tests/serializers/test_bytes.py b/tests/serializers/test_bytes.py index e313138c9..13849bed0 100644 --- a/tests/serializers/test_bytes.py +++ b/tests/serializers/test_bytes.py @@ -105,6 +105,13 @@ def test_bytes_base64(): assert base64.b64decode(s.to_python(b'foo bar', mode='json').encode()) == b'foo bar' +def test_bytes_hex(): + s = SchemaSerializer(core_schema.bytes_schema(), {'ser_json_bytes': 'hex'}) + assert s.to_python(b'\xff\xff') == b'\xff\xff' + assert 
s.to_json(b'\xff\xff') == b'"ffff"' + assert s.to_python(b'\xff\xff', mode='json') == 'ffff' == b'\xff\xff'.hex() + + def test_bytes_base64_dict_key(): s = SchemaSerializer(core_schema.dict_schema(core_schema.bytes_schema()), {'ser_json_bytes': 'base64'}) From e866c1111181c8b4729636fb88ed2ad3b515718f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 13:01:01 +0100 Subject: [PATCH 071/550] Bump strum_macros from 0.25.2 to 0.25.3 (#1021) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3c2af0c54..9e19810d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -460,9 +460,9 @@ dependencies = [ [[package]] name = "strum_macros" -version = "0.25.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 4a6541781..3afde2c50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ include = [ pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] } regex = "1.9.6" strum = { version = "0.25.0", features = ["derive"] } -strum_macros = "0.25.2" +strum_macros = "0.25.3" serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.188", features = ["derive"] } From 085342ba351ddd7ef3452ff6fa8de05224a87531 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 13:01:16 +0100 Subject: [PATCH 072/550] Bump serde from 1.0.188 to 1.0.189 (#1019) Signed-off-by: dependabot[bot] 
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9e19810d3..350a03e14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -403,18 +403,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 3afde2c50..184f6ba0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" -serde = { version = "1.0.188", features = ["derive"] } +serde = { version = "1.0.189", features = ["derive"] } speedate = "0.12.0" smallvec = "1.11.1" ahash = "0.8.0" From d0704489f45045f3699f1d949e5dd96bb5ccf4c3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 13:02:09 +0100 Subject: [PATCH 073/550] Bump regex from 1.9.6 to 1.10.1 (#1020) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 12 ++++++------ Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) 
diff --git a/Cargo.lock b/Cargo.lock index 350a03e14..91e0a13c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -356,9 +356,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.6" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" +checksum = "aaac441002f822bc9705a681810a4dd2963094b9ca0ddc41cb963a4c189189ea" dependencies = [ "aho-corasick", "memchr", @@ -368,9 +368,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.9" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +checksum = "5011c7e263a695dc8ca064cddb722af1be54e517a280b12a5356f98366899e5d" dependencies = [ "aho-corasick", "memchr", @@ -379,9 +379,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.5" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "rustversion" diff --git a/Cargo.toml b/Cargo.toml index 184f6ba0a..1cb996450 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ [dependencies] pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] } -regex = "1.9.6" +regex = "1.10.1" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} From b697abfbb2ab7381996e49697e5355b96ca7184b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 12:03:25 +0000 Subject: [PATCH 074/550] Bump pyright from 1.1.330.post0 to 1.1.331 (#1023) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] 
<49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 25b74ce08..48a7e4a04 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 griffe==0.36.4 -pyright==1.1.330.post0 +pyright==1.1.331 ruff==0.0.292 mypy==1.5.1 From 1b7b6e90979451f89f709a0cc940846828da7d68 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 12:12:05 +0000 Subject: [PATCH 075/550] Bump mypy from 1.5.1 to 1.6.0 (#1022) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 48a7e4a04..7008082ff 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -2,4 +2,4 @@ black==23.9.1 griffe==0.36.4 pyright==1.1.331 ruff==0.0.292 -mypy==1.5.1 +mypy==1.6.0 From f8c0920c977f2755c4c95381beac059b4142ec0e Mon Sep 17 00:00:00 2001 From: sydney-runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Mon, 16 Oct 2023 07:21:05 -0500 Subject: [PATCH 076/550] Fix bug allowing validation of `bool` types with `coerce_numbers_to_str=True` (#1017) --- src/input/input_python.rs | 3 ++- tests/validators/test_string.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/input/input_python.rs b/src/input/input_python.rs index cd8ecf267..fe688a487 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -246,10 +246,11 @@ impl<'a> Input<'a> for PyAny { Err(_) => return Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), }; Ok(s.into()) - } else if coerce_numbers_to_str && { + } else if coerce_numbers_to_str && 
!PyBool::is_exact_type_of(self) && { let py = self.py(); let decimal_type: Py = get_decimal_type(py); + // only allow int, float, and decimal (not bool) self.is_instance_of::() || self.is_instance_of::() || self.is_instance(decimal_type.as_ref(py)).unwrap_or_default() diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index acb145a58..cab6e5127 100644 --- a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -277,6 +277,16 @@ def test_coerce_numbers_to_str_disabled_in_strict_mode() -> None: v.validate_json('42') +def test_coerce_numbers_to_str_raises_for_bool() -> None: + config = core_schema.CoreConfig(coerce_numbers_to_str=True) + + v = SchemaValidator(core_schema.str_schema(), config) + with pytest.raises(ValidationError): + v.validate_python(True) + with pytest.raises(ValidationError): + v.validate_json(False) + + @pytest.mark.parametrize( ('number', 'expected_str'), [ From 38c9e64e61bdca5e487cee7ed8399a9e43e189bd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Oct 2023 12:22:02 +0000 Subject: [PATCH 077/550] Bump griffe from 0.36.4 to 0.36.5 (#1025) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 7008082ff..bcc66c749 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 -griffe==0.36.4 +griffe==0.36.5 pyright==1.1.331 ruff==0.0.292 mypy==1.6.0 From f8470df75b62c061caf5c5c09e26c5d2ac94aa8d Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Tue, 17 Oct 2023 19:20:28 +0100 Subject: [PATCH 078/550] Bump version to 2.11.0 (#1027) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) 
diff --git a/Cargo.lock b/Cargo.lock index 91e0a13c7..95d0363bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,7 +242,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.10.1" +version = "2.11.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 1cb996450..ed3ea71ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.10.1" +version = "2.11.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 8f347ff780a358f3c5efdf0593cc32b786b14b53 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:03:42 +0100 Subject: [PATCH 079/550] Bump regex from 1.10.1 to 1.10.2 (#1036) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 95d0363bc..7e11a04cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -356,9 +356,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.1" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aaac441002f822bc9705a681810a4dd2963094b9ca0ddc41cb963a4c189189ea" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", @@ -368,9 +368,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5011c7e263a695dc8ca064cddb722af1be54e517a280b12a5356f98366899e5d" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", diff --git a/Cargo.toml b/Cargo.toml index ed3ea71ed..a5309acc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ [dependencies] 
pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] } -regex = "1.10.1" +regex = "1.10.2" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} From a3c367ec4fbe9242cb44a411d0051bc696279ae3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:04:16 +0100 Subject: [PATCH 080/550] Bump uuid from 1.4.1 to 1.5.0 (#1035) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7e11a04cf..22f882816 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -543,9 +543,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +checksum = "88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc" [[package]] name = "version_check" diff --git a/Cargo.toml b/Cargo.toml index a5309acc4..03c6bda1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ idna = "0.4.0" base64 = "0.21.4" num-bigint = "0.4.4" python3-dll-a = "0.2.7" -uuid = "1.4.1" +uuid = "1.5.0" [lib] name = "_pydantic_core" From ad1d384568a0847bd9577ec7e3cc7e4fd4d5373f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:04:38 +0100 Subject: [PATCH 081/550] Bump ahash from 0.8.3 to 0.8.4 (#1034) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 33 +++++++++++++++++++++++++++------ Cargo.toml | 2 +- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 
22f882816..ed5454a87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,14 +4,15 @@ version = 3 [[package]] name = "ahash" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +checksum = "72832d73be48bac96a5d7944568f305d829ed55b0ce3b483647089dfaf6cf704" dependencies = [ "cfg-if", "getrandom", "once_cell", "version_check", + "zerocopy", ] [[package]] @@ -233,9 +234,9 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "proc-macro2" -version = "1.0.64" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] @@ -473,9 +474,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.28" +version = "2.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" dependencies = [ "proc-macro2", "quote", @@ -615,3 +616,23 @@ name = "windows_x86_64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + +[[package]] +name = "zerocopy" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c19fae0c8a9efc6a8281f2e623db8af1db9e57852e04cde3e754dd2dc29340f" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc56589e9ddd1f1c28d4b4b5c773ce232910a6bb67a70133d61c9e347585efe9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff 
--git a/Cargo.toml b/Cargo.toml index 03c6bda1b..83b774ecd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ enum_dispatch = "0.3.8" serde = { version = "1.0.189", features = ["derive"] } speedate = "0.12.0" smallvec = "1.11.1" -ahash = "0.8.0" +ahash = "0.8.4" url = "2.4.1" # idna is already required by url, added here to be explicit idna = "0.4.0" From 53619bf8c60564f80db428cef2f2a06cbbd05408 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:05:03 +0100 Subject: [PATCH 082/550] Bump base64 from 0.21.4 to 0.21.5 (#1033) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ed5454a87..7c1ef6624 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,9 +32,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.4" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "bitflags" diff --git a/Cargo.toml b/Cargo.toml index 83b774ecd..ced489442 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ ahash = "0.8.4" url = "2.4.1" # idna is already required by url, added here to be explicit idna = "0.4.0" -base64 = "0.21.4" +base64 = "0.21.5" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.5.0" From 81924c0a864fb1ef9c7bc3c795a429c861640c5f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:05:16 +0100 Subject: [PATCH 083/550] Bump ruff from 0.0.292 to 0.1.1 (#1032) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] 
<49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index bcc66c749..21040cbb5 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 griffe==0.36.5 pyright==1.1.331 -ruff==0.0.292 +ruff==0.1.1 mypy==1.6.0 From 45e8df311be4082eec5de7cc57174e47b86dcaa6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:05:34 +0100 Subject: [PATCH 084/550] Bump griffe from 0.36.5 to 0.36.7 (#1031) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 21040cbb5..8fde2c0a7 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 -griffe==0.36.5 +griffe==0.36.7 pyright==1.1.331 ruff==0.1.1 mypy==1.6.0 From f59bcc7b16d17c9ff19a074c5494ffbf86d90047 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 13:13:21 +0000 Subject: [PATCH 085/550] Bump pyright from 1.1.331 to 1.1.332 (#1029) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 8fde2c0a7..8531d72dd 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.9.1 griffe==0.36.7 -pyright==1.1.331 +pyright==1.1.332 ruff==0.1.1 mypy==1.6.0 From 22255ad209cc879488e9ea4e3fd17b3d4a8feb72 Mon Sep 17 00:00:00 2001 From: 
"dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 13:14:16 +0000 Subject: [PATCH 086/550] Bump mypy from 1.6.0 to 1.6.1 (#1030) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 8531d72dd..2f5b20469 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -2,4 +2,4 @@ black==23.9.1 griffe==0.36.7 pyright==1.1.332 ruff==0.1.1 -mypy==1.6.0 +mypy==1.6.1 From acf15bf30205c1ae72689051055b0d253b46a505 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 13:22:23 +0000 Subject: [PATCH 087/550] Bump black from 23.9.1 to 23.10.0 (#1028) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 2f5b20469..91e454abf 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ -black==23.9.1 +black==23.10.0 griffe==0.36.7 pyright==1.1.332 ruff==0.1.1 From 23d106551d14214f24a0bcd577b74a672ccfd0c0 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Wed, 25 Oct 2023 09:52:26 +0100 Subject: [PATCH 088/550] Don't accept NaN in float and decimal constraints (#1037) --- src/validators/decimal.rs | 19 +++++++++++++++---- src/validators/float.rs | 10 ++++++---- tests/validators/test_decimal.py | 4 ++++ tests/validators/test_float.py | 4 ++++ 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index be19d1eda..730eeac69 100644 --- 
a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -182,8 +182,19 @@ impl Validator for DecimalValidator { } } + // Decimal raises DecimalOperation when comparing NaN, so if it's necessary to compare + // the value to a number, we need to check for NaN first. We cache the result on the first + // time we check it. + let mut is_nan: Option = None; + let mut is_nan = || -> PyResult { + match is_nan { + Some(is_nan) => Ok(is_nan), + None => Ok(*is_nan.insert(decimal.call_method0(intern!(py, "is_nan"))?.extract()?)), + } + }; + if let Some(le) = &self.le { - if !decimal.le(le)? { + if is_nan()? || !decimal.le(le)? { return Err(ValError::new( ErrorType::LessThanEqual { le: Number::String(le.to_string()), @@ -194,7 +205,7 @@ impl Validator for DecimalValidator { } } if let Some(lt) = &self.lt { - if !decimal.lt(lt)? { + if is_nan()? || !decimal.lt(lt)? { return Err(ValError::new( ErrorType::LessThan { lt: Number::String(lt.to_string()), @@ -205,7 +216,7 @@ impl Validator for DecimalValidator { } } if let Some(ge) = &self.ge { - if !decimal.ge(ge)? { + if is_nan()? || !decimal.ge(ge)? { return Err(ValError::new( ErrorType::GreaterThanEqual { ge: Number::String(ge.to_string()), @@ -216,7 +227,7 @@ impl Validator for DecimalValidator { } } if let Some(gt) = &self.gt { - if !decimal.gt(gt)? { + if is_nan()? || !decimal.gt(gt)? 
{ return Err(ValError::new( ErrorType::GreaterThan { gt: Number::String(gt.to_string()), diff --git a/src/validators/float.rs b/src/validators/float.rs index 1d62d2006..646d8f4d8 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -129,7 +131,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(le) = self.le { - if float > le { + if !matches!(float.partial_cmp(&le), Some(Ordering::Less | Ordering::Equal)) { return Err(ValError::new( ErrorType::LessThanEqual { le: le.into(), @@ -140,7 +142,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(lt) = self.lt { - if float >= lt { + if !matches!(float.partial_cmp(<), Some(Ordering::Less)) { return Err(ValError::new( ErrorType::LessThan { lt: lt.into(), @@ -151,7 +153,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(ge) = self.ge { - if float < ge { + if !matches!(float.partial_cmp(&ge), Some(Ordering::Greater | Ordering::Equal)) { return Err(ValError::new( ErrorType::GreaterThanEqual { ge: ge.into(), @@ -162,7 +164,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(gt) = self.gt { - if float <= gt { + if !matches!(float.partial_cmp(>), Some(Ordering::Greater)) { return Err(ValError::new( ErrorType::GreaterThan { gt: gt.into(), diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index 376a9816a..cd54c89ae 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -148,12 +148,16 @@ def test_decimal_strict_json(input_value, expected): ({'le': 0}, 0, Decimal(0)), ({'le': 0}, -1, Decimal(-1)), ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), + ({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')), + ({'gt': 0, 'allow_inf_nan': True}, float('inf'), Decimal('inf')), ({'lt': 0}, 0, Err('Input should be less than 0')), ({'lt': 
0.123456}, 1, Err('Input should be less than 0.123456')), ], ) def test_decimal_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): v = py_and_json({'type': 'decimal', **kwargs}) + if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value): + expected = Err('Invalid JSON') if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_test(input_value) diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index 74f0024ca..b18181fbb 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -86,10 +86,14 @@ def test_float_strict(py_and_json: PyAndJson, input_value, expected): ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), ({'lt': 0}, 0, Err('Input should be less than 0')), ({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')), + ({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')), + ({'gt': 0, 'allow_inf_nan': True}, float('inf'), float('inf')), ], ) def test_float_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): v = py_and_json({'type': 'float', **kwargs}) + if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value): + expected = Err('Invalid JSON') if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_test(input_value) From 866eb2da06321cf4c99118b2331bf7eff7e81e1c Mon Sep 17 00:00:00 2001 From: Michael Huang Date: Thu, 26 Oct 2023 12:50:35 -0400 Subject: [PATCH 089/550] Add lax_str and lax_int support for enum values not inherited from str/int (#1015) --- src/input/input_python.rs | 21 ++++++++++++++++++++- src/input/shared.rs | 16 +++++++++++++++- src/serializers/ob_type.rs | 3 ++- tests/validators/test_int.py | 13 +++++++++++++ tests/validators/test_string.py | 15 +++++++++++++++ 5 files changed, 65 
insertions(+), 3 deletions(-) diff --git a/src/input/input_python.rs b/src/input/input_python.rs index fe688a487..33d7ca296 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -21,7 +21,10 @@ use super::datetime::{ float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; +use super::shared::{ + decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float, + str_as_int, +}; use super::{ py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs, @@ -256,6 +259,8 @@ impl<'a> Input<'a> for PyAny { || self.is_instance(decimal_type.as_ref(py)).unwrap_or_default() } { Ok(self.str()?.into()) + } else if let Some(enum_val) = maybe_as_enum(self) { + Ok(enum_val.str()?.into()) } else { Err(ValError::new(ErrorTypeDefaults::StringType, self)) } @@ -340,6 +345,8 @@ impl<'a> Input<'a> for PyAny { decimal_as_int(self.py(), self, decimal) } else if let Ok(float) = self.extract::() { float_as_int(self, float) + } else if let Some(enum_val) = maybe_as_enum(self) { + Ok(EitherInt::Py(enum_val)) } else { Err(ValError::new(ErrorTypeDefaults::IntType, self)) } @@ -759,6 +766,18 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult Option<&PyAny> { + let py = v.py(); + let enum_meta_object = get_enum_meta_object(py); + let meta_type = v.get_type().get_type(); + if meta_type.is(&enum_meta_object) { + v.getattr(intern!(py, "value")).ok() + } else { + None + } +} + #[cfg(PyPy)] static DICT_KEYS_TYPE: pyo3::once_cell::GILOnceCell> = pyo3::once_cell::GILOnceCell::new(); diff --git a/src/input/shared.rs b/src/input/shared.rs index 1a8e2b61c..105da4bcc 100644 --- a/src/input/shared.rs +++ 
b/src/input/shared.rs @@ -1,11 +1,25 @@ use num_bigint::BigInt; -use pyo3::{intern, PyAny, Python}; +use pyo3::sync::GILOnceCell; +use pyo3::{intern, Py, PyAny, Python, ToPyObject}; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use super::parse_json::{JsonArray, JsonInput}; use super::{EitherFloat, EitherInt, Input}; +static ENUM_META_OBJECT: GILOnceCell> = GILOnceCell::new(); + +pub fn get_enum_meta_object(py: Python) -> Py { + ENUM_META_OBJECT + .get_or_init(py, || { + py.import(intern!(py, "enum")) + .and_then(|enum_module| enum_module.getattr(intern!(py, "EnumMeta"))) + .unwrap() + .to_object(py) + }) + .clone() +} + pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> { ValError::new( ErrorType::JsonInvalid { diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index fc491f618..109aed3bb 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -259,8 +259,9 @@ impl ObTypeLookup { fn is_enum(&self, op_value: Option<&PyAny>, py_type: &PyType) -> bool { // only test on the type itself, not base types if op_value.is_some() { + let enum_meta_type = self.enum_object.as_ref(py_type.py()).get_type(); let meta_type = py_type.get_type(); - meta_type.is(&self.enum_object) + meta_type.is(enum_meta_type) } else { false } diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 8d5850dc8..dedc2bd93 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -459,3 +459,16 @@ def test_float_subclass() -> None: v_lax = v.validate_python(FloatSubclass(1)) assert v_lax == 1 assert type(v_lax) == int + + +def test_int_subclass_plain_enum() -> None: + v = SchemaValidator({'type': 'int'}) + + from enum import Enum + + class PlainEnum(Enum): + ONE = 1 + + v_lax = v.validate_python(PlainEnum.ONE) + assert v_lax == 1 + assert type(v_lax) == int diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index 
cab6e5127..bc2102de2 100644 --- a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -249,6 +249,21 @@ def test_lax_subclass(FruitEnum, kwargs): assert repr(p) == "'pear'" +@pytest.mark.parametrize('kwargs', [{}, {'to_lower': True}], ids=repr) +def test_lax_subclass_plain_enum(kwargs): + v = SchemaValidator(core_schema.str_schema(**kwargs)) + + from enum import Enum + + class PlainEnum(Enum): + ONE = 'one' + + p = v.validate_python(PlainEnum.ONE) + assert p == 'one' + assert type(p) is str + assert repr(p) == "'one'" + + def test_subclass_preserved() -> None: class StrSubclass(str): pass From dd75669bbb7a624e714afa66d3c7f8e24f339a7c Mon Sep 17 00:00:00 2001 From: sydney-runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Mon, 30 Oct 2023 10:34:36 -0500 Subject: [PATCH 090/550] Support subclasses in lists in `Union` of `List` types (#1039) --- src/serializers/type_serializers/list.rs | 4 ++ src/serializers/type_serializers/union.rs | 3 +- tests/serializers/test_union.py | 56 +++++++++++++++++++++++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/src/serializers/type_serializers/list.rs b/src/serializers/type_serializers/list.rs index 4d68ae373..a71e3452a 100644 --- a/src/serializers/type_serializers/list.rs +++ b/src/serializers/type_serializers/list.rs @@ -116,4 +116,8 @@ impl TypeSerializer for ListSerializer { fn get_name(&self) -> &str { &self.name } + + fn retry_with_lax_check(&self) -> bool { + self.item_serializer.retry_with_lax_check() + } } diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 788620408..f05e2220e 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -75,9 +75,10 @@ impl TypeSerializer for UnionSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - // try the serializers in with error_on fallback=true + // try the serializers in left to right order with error_on fallback=true 
let mut new_extra = extra.clone(); new_extra.check = SerCheck::Strict; + for comb_serializer in &self.choices { match comb_serializer.to_python(value, include, exclude, &new_extra) { Ok(v) => return Ok(v), diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index f81e33a6b..9b021e66e 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -454,3 +454,59 @@ class Item(BaseModel): ) assert s.to_python(DBUser(name='John', password='secret')) == {'name': 'John'} + + +def test_union_serializes_list_of_model_subclass_from_definition() -> None: + class BaseModel: + __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' + + def __init__(self, **kwargs: Any): + for key, value in kwargs.items(): + setattr(self, key, value) + + class User(BaseModel): + name: str + + class DBUser(User): + password: str + __pydantic_serializer__: ClassVar[SchemaSerializer] + + DBUser.__pydantic_serializer__ = SchemaSerializer( + core_schema.model_schema( + DBUser, + core_schema.model_fields_schema( + { + 'name': core_schema.model_field(core_schema.str_schema()), + 'password': core_schema.model_field(core_schema.str_schema()), + } + ), + ) + ) + + class Item(BaseModel): + price: float + + s = SchemaSerializer( + core_schema.definitions_schema( + core_schema.union_schema( + [ + core_schema.list_schema(core_schema.definition_reference_schema('User'), strict=False), + core_schema.list_schema(core_schema.definition_reference_schema('Item'), strict=False), + ] + ), + [ + core_schema.model_schema( + User, + core_schema.model_fields_schema({'name': core_schema.model_field(core_schema.str_schema())}), + ref='User', + ), + core_schema.model_schema( + Item, + core_schema.model_fields_schema({'price': core_schema.model_field(core_schema.float_schema())}), + ref='Item', + ), + ], + ) + ) + + assert s.to_python([DBUser(name='John', password='secret')]) == [{'name': 'John'}] From 
ef3e8132e51df7189985c8a62c6bc2cb73879a13 Mon Sep 17 00:00:00 2001 From: sydney-runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Tue, 31 Oct 2023 09:26:59 -0500 Subject: [PATCH 091/550] Allow validation against `max_digits` and `decimals` to pass if normalized or non-normalized input is valid (#1049) --- src/validators/decimal.rs | 129 ++++++++++++++++++------------- tests/validators/test_decimal.py | 28 +++++++ 2 files changed, 104 insertions(+), 53 deletions(-) diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index 730eeac69..eb3141c31 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -83,6 +83,41 @@ impl_py_gc_traverse!(DecimalValidator { gt }); +fn extract_decimal_digits_info<'data>( + decimal: &PyAny, + normalized: bool, + py: Python<'data>, +) -> ValResult<'data, (u64, u64)> { + let mut normalized_decimal: Option<&PyAny> = None; + if normalized { + normalized_decimal = Some(decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal)); + } + let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = normalized_decimal + .unwrap_or(decimal) + .call_method0(intern!(py, "as_tuple"))? + .extract()?; + + // finite values have numeric exponent, we checked is_finite above + let exponent: i64 = exponent.extract()?; + let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?; + let decimals; + if exponent >= 0 { + // A positive exponent adds that many trailing zeros. + digits += exponent as u64; + decimals = 0; + } else { + // If the absolute value of the negative exponent is larger than the + // number of digits, then it's the same as the number of digits, + // because it'll consume all the digits in digit_tuple and then + // add abs(exponent) - len(digit_tuple) leading zeros after the + // decimal point. 
+ decimals = exponent.unsigned_abs(); + digits = digits.max(decimals); + } + + Ok((decimals, digits)) +} + impl Validator for DecimalValidator { fn validate<'data>( &self, @@ -98,65 +133,53 @@ impl Validator for DecimalValidator { } if self.check_digits { - let normalized_value = decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal); - let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = - normalized_value.call_method0(intern!(py, "as_tuple"))?.extract()?; + if let Ok((normalized_decimals, normalized_digits)) = extract_decimal_digits_info(decimal, true, py) { + if let Ok((decimals, digits)) = extract_decimal_digits_info(decimal, false, py) { + if let Some(max_digits) = self.max_digits { + if (digits > max_digits) & (normalized_digits > max_digits) { + return Err(ValError::new( + ErrorType::DecimalMaxDigits { + max_digits, + context: None, + }, + input, + )); + } + } - // finite values have numeric exponent, we checked is_finite above - let exponent: i64 = exponent.extract()?; - let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?; - let decimals; - if exponent >= 0 { - // A positive exponent adds that many trailing zeros. - digits += exponent as u64; - decimals = 0; - } else { - // If the absolute value of the negative exponent is larger than the - // number of digits, then it's the same as the number of digits, - // because it'll consume all the digits in digit_tuple and then - // add abs(exponent) - len(digit_tuple) leading zeros after the - // decimal point. 
- decimals = exponent.unsigned_abs(); - digits = digits.max(decimals); - } + if let Some(decimal_places) = self.decimal_places { + if (decimals > decimal_places) & (normalized_decimals > decimal_places) { + return Err(ValError::new( + ErrorType::DecimalMaxPlaces { + decimal_places, + context: None, + }, + input, + )); + } - if let Some(max_digits) = self.max_digits { - if digits > max_digits { - return Err(ValError::new( - ErrorType::DecimalMaxDigits { - max_digits, - context: None, - }, - input, - )); - } - } + if let Some(max_digits) = self.max_digits { + let whole_digits = digits.saturating_sub(decimals); + let max_whole_digits = max_digits.saturating_sub(decimal_places); - if let Some(decimal_places) = self.decimal_places { - if decimals > decimal_places { - return Err(ValError::new( - ErrorType::DecimalMaxPlaces { - decimal_places, - context: None, - }, - input, - )); - } + let normalized_whole_digits = normalized_digits.saturating_sub(normalized_decimals); + let normalized_max_whole_digits = max_digits.saturating_sub(decimal_places); - if let Some(max_digits) = self.max_digits { - let whole_digits = digits.saturating_sub(decimals); - let max_whole_digits = max_digits.saturating_sub(decimal_places); - if whole_digits > max_whole_digits { - return Err(ValError::new( - ErrorType::DecimalWholeDigits { - whole_digits: max_whole_digits, - context: None, - }, - input, - )); + if (whole_digits > max_whole_digits) + & (normalized_whole_digits > normalized_max_whole_digits) + { + return Err(ValError::new( + ErrorType::DecimalWholeDigits { + whole_digits: max_whole_digits, + context: None, + }, + input, + )); + } + } } } - } + }; } } diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index cd54c89ae..43b3d19b9 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -437,3 +437,31 @@ def test_non_finite_constrained_decimal_values(input_value, allow_inf_nan, expec def 
test_validate_scientific_notation_from_json(input_value, expected): v = SchemaValidator({'type': 'decimal'}) assert v.validate_json(input_value) == expected + + +def test_validate_max_digits_and_decimal_places() -> None: + v = SchemaValidator({'type': 'decimal', 'max_digits': 5, 'decimal_places': 2}) + + # valid inputs + assert v.validate_json('1.23') == Decimal('1.23') + assert v.validate_json('123.45') == Decimal('123.45') + assert v.validate_json('-123.45') == Decimal('-123.45') + + # invalid inputs + with pytest.raises(ValidationError): + v.validate_json('1234.56') # too many digits + with pytest.raises(ValidationError): + v.validate_json('123.456') # too many decimal places + with pytest.raises(ValidationError): + v.validate_json('123456') # too many digits + with pytest.raises(ValidationError): + v.validate_json('abc') # not a valid decimal + + +def test_validate_max_digits_and_decimal_places_edge_case() -> None: + v = SchemaValidator({'type': 'decimal', 'max_digits': 34, 'decimal_places': 18}) + + # valid inputs + assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal( + '9999999999999999.999999999999999999' + ) From 383535655507ac3eeb4e3f3755f5e239bf021c28 Mon Sep 17 00:00:00 2001 From: Iipin <52832022+Iipin@users.noreply.github.com> Date: Fri, 3 Nov 2023 19:22:13 +0100 Subject: [PATCH 092/550] Fix proper pluralization in validation error messages (#1050) --- src/errors/types.rs | 60 +++++++++++++++++++++++++++++++------------- tests/test_errors.py | 12 +++++++++ 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/src/errors/types.rs b/src/errors/types.rs index e31307e2b..5c3fc1a7c 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -445,8 +445,8 @@ macro_rules! 
to_string_render { }; } -fn plural_s(value: usize) -> &'static str { - if value == 1 { +fn plural_s + PartialEq>(value: T) -> &'static str { + if value == 1.into() { "" } else { "s" @@ -494,8 +494,8 @@ impl ErrorType { Self::StringType {..} => "Input should be a valid string", Self::StringSubType {..} => "Input should be a string, not an instance of a subclass of str", Self::StringUnicode {..} => "Input should be a valid string, unable to parse raw data as a unicode string", - Self::StringTooShort {..} => "String should have at least {min_length} characters", - Self::StringTooLong {..} => "String should have at most {max_length} characters", + Self::StringTooShort {..} => "String should have at least {min_length} character{expected_plural}", + Self::StringTooLong {..} => "String should have at most {max_length} character{expected_plural}", Self::StringPatternMismatch {..} => "String should match pattern '{pattern}'", Self::Enum {..} => "Input should be {expected}", Self::DictType {..} => "Input should be a valid dictionary", @@ -512,8 +512,8 @@ impl ErrorType { Self::FloatType {..} => "Input should be a valid number", Self::FloatParsing {..} => "Input should be a valid number, unable to parse string as a number", Self::BytesType {..} => "Input should be a valid bytes", - Self::BytesTooShort {..} => "Data should have at least {min_length} bytes", - Self::BytesTooLong {..} => "Data should have at most {max_length} bytes", + Self::BytesTooShort {..} => "Data should have at least {min_length} byte{expected_plural}", + Self::BytesTooLong {..} => "Data should have at most {max_length} byte{expected_plural}", Self::ValueError {..} => "Value error, {error}", Self::AssertionError {..} => "Assertion failed, {error}", Self::CustomError {..} => "", // custom errors are handled separately @@ -552,16 +552,16 @@ impl ErrorType { Self::UrlType {..} => "URL input should be a string or URL", Self::UrlParsing {..} => "Input should be a valid URL, {error}", Self::UrlSyntaxViolation 
{..} => "Input violated strict URL syntax rules, {error}", - Self::UrlTooLong {..} => "URL should have at most {max_length} characters", + Self::UrlTooLong {..} => "URL should have at most {max_length} character{expected_plural}", Self::UrlScheme {..} => "URL scheme should be {expected_schemes}", Self::UuidType {..} => "UUID input should be a string, bytes or UUID object", Self::UuidParsing {..} => "Input should be a valid UUID, {error}", Self::UuidVersion {..} => "UUID version {expected_version} expected", Self::DecimalType {..} => "Decimal input should be an integer, float, string or Decimal object", Self::DecimalParsing {..} => "Input should be a valid decimal", - Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digits in total", - Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal places", - Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digits before the decimal point", + Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total", + Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}", + Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point", } } @@ -643,13 +643,25 @@ impl ErrorType { to_string_render!(tmpl, field_type, max_length, actual_length, expected_plural,) } Self::IterationError { error, .. } => render!(tmpl, error), - Self::StringTooShort { min_length, .. } => to_string_render!(tmpl, min_length), - Self::StringTooLong { max_length, .. } => to_string_render!(tmpl, max_length), + Self::StringTooShort { min_length, .. } => { + let expected_plural = plural_s(*min_length); + to_string_render!(tmpl, min_length, expected_plural) + } + Self::StringTooLong { max_length, .. 
} => { + let expected_plural = plural_s(*max_length); + to_string_render!(tmpl, max_length, expected_plural) + } Self::StringPatternMismatch { pattern, .. } => render!(tmpl, pattern), Self::Enum { expected, .. } => to_string_render!(tmpl, expected), Self::MappingType { error, .. } => render!(tmpl, error), - Self::BytesTooShort { min_length, .. } => to_string_render!(tmpl, min_length), - Self::BytesTooLong { max_length, .. } => to_string_render!(tmpl, max_length), + Self::BytesTooShort { min_length, .. } => { + let expected_plural = plural_s(*min_length); + to_string_render!(tmpl, min_length, expected_plural) + } + Self::BytesTooLong { max_length, .. } => { + let expected_plural = plural_s(*max_length); + to_string_render!(tmpl, max_length, expected_plural) + } Self::ValueError { error, .. } => { let error = &error .as_ref() @@ -688,13 +700,25 @@ impl ErrorType { Self::UnionTagNotFound { discriminator, .. } => render!(tmpl, discriminator), Self::UrlParsing { error, .. } => render!(tmpl, error), Self::UrlSyntaxViolation { error, .. } => render!(tmpl, error), - Self::UrlTooLong { max_length, .. } => to_string_render!(tmpl, max_length), + Self::UrlTooLong { max_length, .. } => { + let expected_plural = plural_s(*max_length); + to_string_render!(tmpl, max_length, expected_plural) + } Self::UrlScheme { expected_schemes, .. } => render!(tmpl, expected_schemes), Self::UuidParsing { error, .. } => render!(tmpl, error), Self::UuidVersion { expected_version, .. } => to_string_render!(tmpl, expected_version), - Self::DecimalMaxDigits { max_digits, .. } => to_string_render!(tmpl, max_digits), - Self::DecimalMaxPlaces { decimal_places, .. } => to_string_render!(tmpl, decimal_places), - Self::DecimalWholeDigits { whole_digits, .. } => to_string_render!(tmpl, whole_digits), + Self::DecimalMaxDigits { max_digits, .. } => { + let expected_plural = plural_s(*max_digits); + to_string_render!(tmpl, max_digits, expected_plural) + } + Self::DecimalMaxPlaces { decimal_places, .. 
} => { + let expected_plural = plural_s(*decimal_places); + to_string_render!(tmpl, decimal_places, expected_plural) + } + Self::DecimalWholeDigits { whole_digits, .. } => { + let expected_plural = plural_s(*whole_digits); + to_string_render!(tmpl, whole_digits, expected_plural) + } _ => Ok(tmpl.to_string()), } } diff --git a/tests/test_errors.py b/tests/test_errors.py index 293880977..3683150ec 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -289,7 +289,9 @@ def f(input_value, info): ('string_unicode', 'Input should be a valid string, unable to parse raw data as a unicode string', None), ('string_pattern_mismatch', "String should match pattern 'foo'", {'pattern': 'foo'}), ('string_too_short', 'String should have at least 42 characters', {'min_length': 42}), + ('string_too_short', 'String should have at least 1 character', {'min_length': 1}), ('string_too_long', 'String should have at most 42 characters', {'max_length': 42}), + ('string_too_long', 'String should have at most 1 character', {'max_length': 1}), ('dict_type', 'Input should be a valid dictionary', None), ('mapping_type', 'Input should be a valid mapping, error: foobar', {'error': 'foobar'}), ('iterable_type', 'Input should be iterable', None), @@ -312,7 +314,9 @@ def f(input_value, info): ('float_parsing', 'Input should be a valid number, unable to parse string as a number', None), ('bytes_type', 'Input should be a valid bytes', None), ('bytes_too_short', 'Data should have at least 42 bytes', {'min_length': 42}), + ('bytes_too_short', 'Data should have at least 1 byte', {'min_length': 1}), ('bytes_too_long', 'Data should have at most 42 bytes', {'max_length': 42}), + ('bytes_too_long', 'Data should have at most 1 byte', {'max_length': 1}), ('value_error', 'Value error, foobar', {'error': ValueError('foobar')}), ('assertion_error', 'Assertion failed, foobar', {'error': AssertionError('foobar')}), ('literal_error', 'Input should be foo', {'expected': 'foo'}), @@ -356,6 +360,7 @@ def 
f(input_value, info): ('url_parsing', 'Input should be a valid URL, Foobar', {'error': 'Foobar'}), ('url_syntax_violation', 'Input violated strict URL syntax rules, Foobar', {'error': 'Foobar'}), ('url_too_long', 'URL should have at most 42 characters', {'max_length': 42}), + ('url_too_long', 'URL should have at most 1 character', {'max_length': 1}), ('url_scheme', 'URL scheme should be "foo", "bar" or "spam"', {'expected_schemes': '"foo", "bar" or "spam"'}), ('uuid_type', 'UUID input should be a string, bytes or UUID object', None), ('uuid_parsing', 'Input should be a valid UUID, Foobar', {'error': 'Foobar'}), @@ -363,12 +368,19 @@ def f(input_value, info): ('decimal_type', 'Decimal input should be an integer, float, string or Decimal object', None), ('decimal_parsing', 'Input should be a valid decimal', None), ('decimal_max_digits', 'Decimal input should have no more than 42 digits in total', {'max_digits': 42}), + ('decimal_max_digits', 'Decimal input should have no more than 1 digit in total', {'max_digits': 1}), ('decimal_max_places', 'Decimal input should have no more than 42 decimal places', {'decimal_places': 42}), + ('decimal_max_places', 'Decimal input should have no more than 1 decimal place', {'decimal_places': 1}), ( 'decimal_whole_digits', 'Decimal input should have no more than 42 digits before the decimal point', {'whole_digits': 42}, ), + ( + 'decimal_whole_digits', + 'Decimal input should have no more than 1 digit before the decimal point', + {'whole_digits': 1}, + ), ] From 73c431b2c766e3c75541a2bb153286e0e6e87d61 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:36:46 +0000 Subject: [PATCH 093/550] Bump serde from 1.0.189 to 1.0.190 (#1047) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock 
b/Cargo.lock index 7c1ef6624..78a717da3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -404,18 +404,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.189" +version = "1.0.190" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" +checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.189" +version = "1.0.190" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" +checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index ced489442..c4f79d32d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" -serde = { version = "1.0.189", features = ["derive"] } +serde = { version = "1.0.190", features = ["derive"] } speedate = "0.12.0" smallvec = "1.11.1" ahash = "0.8.4" From 8dba67ab8ed213ebf213a856e40ec52d5653ca0d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:37:08 +0000 Subject: [PATCH 094/550] Bump ahash from 0.8.4 to 0.8.6 (#1046) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 12 ++++++------ Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 78a717da3..14804774d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "ahash" -version = "0.8.4" +version = "0.8.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "72832d73be48bac96a5d7944568f305d829ed55b0ce3b483647089dfaf6cf704" +checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" dependencies = [ "cfg-if", "getrandom", @@ -619,18 +619,18 @@ checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" [[package]] name = "zerocopy" -version = "0.7.11" +version = "0.7.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c19fae0c8a9efc6a8281f2e623db8af1db9e57852e04cde3e754dd2dc29340f" +checksum = "dd66a62464e3ffd4e37bd09950c2b9dd6c4f8767380fabba0d523f9a775bc85a" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.11" +version = "0.7.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc56589e9ddd1f1c28d4b4b5c773ce232910a6bb67a70133d61c9e347585efe9" +checksum = "255c4596d41e6916ced49cfafea18727b24d67878fa180ddfd69b9df34fd1726" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index c4f79d32d..cdd1b7056 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ enum_dispatch = "0.3.8" serde = { version = "1.0.190", features = ["derive"] } speedate = "0.12.0" smallvec = "1.11.1" -ahash = "0.8.4" +ahash = "0.8.6" url = "2.4.1" # idna is already required by url, added here to be explicit idna = "0.4.0" From d98f14ff12669c5be6d0278b4b5ecd0876a2176a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:37:31 +0000 Subject: [PATCH 095/550] Bump actions/setup-node from 3 to 4 (#1045) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 85174fa41..62e998bda 100644 --- a/.github/workflows/ci.yml +++ 
b/.github/workflows/ci.yml @@ -236,7 +236,7 @@ jobs: python-version: '3.11' # used to lint js code - - uses: actions/setup-node@v3 + - uses: actions/setup-node@v4 with: node-version: '18' @@ -318,7 +318,7 @@ jobs: - name: build wheels run: make build-wasm - - uses: actions/setup-node@v3 + - uses: actions/setup-node@v4 with: node-version: '18' From fdaf9da2090a9e909757caabe72dbee87ac2b115 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:38:05 +0000 Subject: [PATCH 096/550] Bump pytest from 7.4.2 to 7.4.3 (#1044) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 98cef770f..66dda075f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,7 +3,7 @@ dirty-equals==0.6.0 hypothesis==6.79.4 # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux pandas==2.0.3; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' -pytest==7.4.2 +pytest==7.4.3 # we run codspeed benchmarks on x86_64 CPython (i.e. 
native github actions architecture) pytest-codspeed~=2.2.0; implementation_name == "cpython" and platform_machine == 'x86_64' pytest-examples==0.0.10 From 8484cec113c3c7a0e0b14d184211989205bea62f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:38:16 +0000 Subject: [PATCH 097/550] Bump ruff from 0.1.1 to 0.1.3 (#1043) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 91e454abf..e12ef0928 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.10.0 griffe==0.36.7 pyright==1.1.332 -ruff==0.1.1 +ruff==0.1.3 mypy==1.6.1 From acf3361478499e4341644d66094011f6283a49cf Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 09:38:27 +0000 Subject: [PATCH 098/550] Bump griffe from 0.36.7 to 0.36.9 (#1041) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index e12ef0928..f0c468109 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.10.0 -griffe==0.36.7 +griffe==0.36.9 pyright==1.1.332 ruff==0.1.3 mypy==1.6.1 From 9d07a8c6b6285a16192dc99e2d9ade3d3cf2214d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 10:38:13 +0000 Subject: [PATCH 099/550] Bump pyright from 1.1.332 to 1.1.334 (#1055) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> 
--- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index f0c468109..79a41f7ea 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,5 @@ black==23.10.0 griffe==0.36.9 -pyright==1.1.332 +pyright==1.1.334 ruff==0.1.3 mypy==1.6.1 From 1f18da208721d993d344934adc4bcc5bbfadb17f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 6 Nov 2023 11:32:53 +0000 Subject: [PATCH 100/550] jiter (#974) --- Cargo.lock | 85 +++++++++ Cargo.toml | 2 + python/pydantic_core/__init__.py | 2 + python/pydantic_core/_pydantic_core.pyi | 18 ++ src/errors/line_error.rs | 6 +- src/errors/location.rs | 28 +-- src/errors/mod.rs | 2 +- src/input/input_abstract.rs | 12 +- src/input/input_json.rs | 167 +++++++++--------- src/input/input_python.rs | 59 ++++--- src/input/input_string.rs | 13 +- src/input/mod.rs | 2 - src/input/parse_json.rs | 222 ------------------------ src/input/return_enums.rs | 16 +- src/input/shared.rs | 12 +- src/lazy_index_map.rs | 63 ------- src/lib.rs | 17 +- src/lookup_key.rs | 18 +- src/validators/arguments.rs | 2 +- src/validators/dataclass.rs | 2 +- src/validators/dict.rs | 2 +- src/validators/function.rs | 22 +-- src/validators/model_fields.rs | 2 +- src/validators/typed_dict.rs | 2 +- src/validators/union.rs | 4 +- tests/test_json.py | 9 +- tests/validators/test_decimal.py | 7 +- tests/validators/test_float.py | 37 +++- tests/validators/test_function.py | 15 +- tests/validators/test_int.py | 26 ++- 30 files changed, 385 insertions(+), 489 deletions(-) delete mode 100644 src/input/parse_json.rs delete mode 100644 src/lazy_index_map.rs diff --git a/Cargo.lock b/Cargo.lock index 14804774d..f16b93697 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,6 +136,84 @@ version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" 
+[[package]] +name = "jiter" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b27d419c535bf7b50ad355278b1159cbf0cc8d507ea003d625b17bf0375720b8" +dependencies = [ + "ahash", + "lexical-core", + "num-bigint", + "num-traits", + "pyo3", + "smallvec", +] + +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-parse-integer" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + 
"lexical-util", + "static_assertions", +] + [[package]] name = "libc" version = "0.2.147" @@ -249,6 +327,7 @@ dependencies = [ "base64", "enum_dispatch", "idna", + "jiter", "num-bigint", "pyo3", "pyo3-build-config", @@ -450,6 +529,12 @@ dependencies = [ "strum_macros", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "strum" version = "0.25.0" diff --git a/Cargo.toml b/Cargo.toml index cdd1b7056..211fdfe01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,8 @@ base64 = "0.21.5" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.5.0" +jiter = {version = "0.0.4", features = ["python"]} +#jiter = {path = "../jiter", features = ["python"]} [lib] name = "_pydantic_core" diff --git a/python/pydantic_core/__init__.py b/python/pydantic_core/__init__.py index a46a77b7d..5b2655c91 100644 --- a/python/pydantic_core/__init__.py +++ b/python/pydantic_core/__init__.py @@ -22,6 +22,7 @@ Url, ValidationError, __version__, + from_json, to_json, to_jsonable_python, validate_core_schema, @@ -63,6 +64,7 @@ 'PydanticSerializationUnexpectedValue', 'TzInfo', 'to_json', + 'from_json', 'to_jsonable_python', 'validate_core_schema', ] diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 8ed3092a9..f28b7a12a 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -41,6 +41,7 @@ __all__ = [ 'PydanticUndefinedType', 'Some', 'to_json', + 'from_json', 'to_jsonable_python', 'list_all_errors', 'TzInfo', @@ -384,6 +385,23 @@ def to_json( JSON bytes. """ +def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> Any: + """ + Deserialize JSON data to a Python object. + + This is effectively a faster version of [`json.loads()`][json.loads]. + + Arguments: + data: The JSON data to deserialize. 
+ allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default. + + Raises: + ValueError: If deserialization fails. + + Returns: + The deserialized Python object. + """ + def to_jsonable_python( value: Any, *, diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index e5d3c7bac..3ee4c7894 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -2,7 +2,9 @@ use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::PyDowncastError; -use crate::input::{Input, JsonInput}; +use jiter::JsonValue; + +use crate::input::Input; use super::location::{LocItem, Location}; use super::types::ErrorType; @@ -147,7 +149,7 @@ impl<'a> ValLineError<'a> { #[derive(Clone)] pub enum InputValue<'a> { PyAny(&'a PyAny), - JsonInput(JsonInput), + JsonInput(JsonValue), String(&'a str), } diff --git a/src/errors/location.rs b/src/errors/location.rs index e5c32d5e2..8acc2a039 100644 --- a/src/errors/location.rs +++ b/src/errors/location.rs @@ -3,12 +3,11 @@ use pyo3::once_cell::GILOnceCell; use std::fmt; use pyo3::prelude::*; -use pyo3::types::{PyList, PyString, PyTuple}; +use pyo3::types::{PyList, PyTuple}; use serde::ser::SerializeSeq; use serde::{Serialize, Serializer}; use crate::lookup_key::{LookupPath, PathItem}; -use crate::tools::extract_i64; /// Used to store individual items of the error location, e.g. a string for key/field names /// or a number for array indices. 
@@ -35,6 +34,12 @@ impl fmt::Display for LocItem { } } +// TODO rename to ToLocItem +pub trait AsLocItem { + // TODO rename to to_loc_item + fn as_loc_item(&self) -> LocItem; +} + impl From for LocItem { fn from(s: String) -> Self { Self::S(s) @@ -82,21 +87,6 @@ impl ToPyObject for LocItem { } } -impl TryFrom<&PyAny> for LocItem { - type Error = PyErr; - - fn try_from(loc_item: &PyAny) -> PyResult { - if let Ok(py_str) = loc_item.downcast::() { - let str = py_str.to_str()?.to_string(); - Ok(Self::S(str)) - } else if let Ok(int) = extract_i64(loc_item) { - Ok(Self::I(int)) - } else { - Err(PyTypeError::new_err("Item in a location must be a string or int")) - } - } -} - impl Serialize for LocItem { fn serialize(&self, serializer: S) -> Result where @@ -211,9 +201,9 @@ impl TryFrom> for Location { fn try_from(location: Option<&PyAny>) -> PyResult { if let Some(location) = location { let mut loc_vec: Vec = if let Ok(tuple) = location.downcast::() { - tuple.iter().map(LocItem::try_from).collect::>()? + tuple.iter().map(AsLocItem::as_loc_item).collect() } else if let Ok(list) = location.downcast::() { - list.iter().map(LocItem::try_from).collect::>()? 
+ list.iter().map(AsLocItem::as_loc_item).collect() } else { return Err(PyTypeError::new_err( "Location must be a list or tuple of strings and ints", diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 6a253197f..bfc5b4329 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -7,7 +7,7 @@ mod validation_exception; mod value_exception; pub use self::line_error::{InputValue, ValError, ValLineError, ValResult}; -pub use self::location::LocItem; +pub use self::location::{AsLocItem, LocItem}; pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number}; pub use self::validation_exception::ValidationError; pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault}; diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 655ba24b9..52551ef42 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -4,13 +4,15 @@ use pyo3::exceptions::PyValueError; use pyo3::types::{PyDict, PyType}; use pyo3::{intern, prelude::*}; -use crate::errors::{InputValue, LocItem, ValResult}; +use jiter::JsonValue; + +use crate::errors::{AsLocItem, InputValue, ValResult}; use crate::tools::py_err; use crate::{PyMultiHostUrl, PyUrl}; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; use super::return_enums::{EitherBytes, EitherInt, EitherString}; -use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonInput}; +use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping}; #[derive(Debug, Clone, Copy)] pub enum InputType { @@ -46,9 +48,7 @@ impl TryFrom<&str> for InputType { /// the convention is to either implement: /// * `strict_*` & `lax_*` if they have different behavior /// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same -pub trait Input<'a>: fmt::Debug + ToPyObject { - fn as_loc_item(&self) -> LocItem; - +pub trait 
Input<'a>: fmt::Debug + ToPyObject + AsLocItem { fn as_error_value(&'a self) -> InputValue<'a>; fn identity(&self) -> Option { @@ -89,7 +89,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>>; - fn parse_json(&'a self) -> ValResult<'a, JsonInput>; + fn parse_json(&'a self) -> ValResult<'a, JsonValue>; fn validate_str(&'a self, strict: bool, coerce_numbers_to_str: bool) -> ValResult> { if strict { diff --git a/src/input/input_json.rs b/src/input/input_json.rs index e375f5755..ac552621d 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -1,46 +1,48 @@ use std::borrow::Cow; +use jiter::{JsonArray, JsonValue}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; -use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::validators::decimal::create_decimal; use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int, string_to_vec}; +use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, - GenericIterator, GenericMapping, Input, JsonArgs, JsonArray, JsonInput, + GenericIterator, GenericMapping, Input, JsonArgs, }; -impl<'a> Input<'a> for JsonInput { - /// This is required by since JSON object keys are always strings, I don't think it can be called - #[cfg_attr(has_coverage_attribute, 
coverage(off))] +/// This is required but since JSON object keys are always strings, I don't think it can be called +impl AsLocItem for JsonValue { fn as_loc_item(&self) -> LocItem { match self { - JsonInput::Int(i) => (*i).into(), - JsonInput::String(s) => s.as_str().into(), + JsonValue::Int(i) => (*i).into(), + JsonValue::Str(s) => s.as_str().into(), v => format!("{v:?}").into(), } } +} +impl<'a> Input<'a> for JsonValue { fn as_error_value(&'a self) -> InputValue<'a> { - // cloning JsonInput is cheap due to use of Arc + // cloning JsonValue is cheap due to use of Arc InputValue::JsonInput(self.clone()) } fn is_none(&self) -> bool { - matches!(self, JsonInput::Null) + matches!(self, JsonValue::Null) } fn as_kwargs(&'a self, py: Python<'a>) -> Option<&'a PyDict> { match self { - JsonInput::Object(object) => { + JsonValue::Object(object) => { let dict = PyDict::new(py); for (k, v) in object.iter() { dict.set_item(k, v.to_object(py)).unwrap(); @@ -53,15 +55,15 @@ impl<'a> Input<'a> for JsonInput { fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { match self { - JsonInput::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), - JsonInput::Array(array) => Ok(JsonArgs::new(Some(array), None).into()), + JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), + JsonValue::Array(array) => Ok(JsonArgs::new(Some(array), None).into()), _ => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)), } } fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> { match self { - JsonInput::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), + JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), _ => { let class_name = class_name.to_string(); Err(ValError::new( @@ -75,33 +77,32 @@ impl<'a> Input<'a> for JsonInput { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { match self { - JsonInput::String(s) 
=> serde_json::from_str(s.as_str()).map_err(|e| map_json_err(self, e)), + JsonValue::Str(s) => JsonValue::parse(s.as_bytes(), true).map_err(|e| map_json_err(self, e)), _ => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), } } fn strict_str(&'a self) -> ValResult> { match self { - JsonInput::String(s) => Ok(s.as_str().into()), + JsonValue::Str(s) => Ok(s.as_str().into()), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { match self { - JsonInput::String(s) => Ok(s.as_str().into()), - JsonInput::BigInt(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Float(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Int(v) if coerce_numbers_to_str => Ok(v.to_string().into()), - JsonInput::Uint(v) if coerce_numbers_to_str => Ok(v.to_string().into()), + JsonValue::Str(s) => Ok(s.as_str().into()), + JsonValue::Int(i) if coerce_numbers_to_str => Ok(i.to_string().into()), + JsonValue::BigInt(b) if coerce_numbers_to_str => Ok(b.to_string().into()), + JsonValue::Float(f) if coerce_numbers_to_str => Ok(f.to_string().into()), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } fn validate_bytes(&'a self, _strict: bool) -> ValResult> { match self { - JsonInput::String(s) => Ok(s.as_bytes().into()), + JsonValue::Str(s) => Ok(s.as_bytes().into()), _ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), } } @@ -112,16 +113,16 @@ impl<'a> Input<'a> for JsonInput { fn strict_bool(&self) -> ValResult { match self { - JsonInput::Bool(b) => Ok(*b), + JsonValue::Bool(b) => Ok(*b), _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), } } fn lax_bool(&self) -> ValResult { match self { - JsonInput::Bool(b) => Ok(*b), - JsonInput::String(s) => str_as_bool(self, s), - JsonInput::Int(int) => int_as_bool(self, *int), - JsonInput::Float(float) => match float_as_int(self, *float) { + JsonValue::Bool(b) => Ok(*b), + JsonValue::Str(s) => str_as_bool(self, 
s), + JsonValue::Int(int) => int_as_bool(self, *int), + JsonValue::Float(float) => match float_as_int(self, *float) { Ok(int) => int .as_bool() .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)), @@ -133,60 +134,56 @@ impl<'a> Input<'a> for JsonInput { fn strict_int(&'a self) -> ValResult> { match self { - JsonInput::Int(i) => Ok(EitherInt::I64(*i)), - JsonInput::Uint(u) => Ok(EitherInt::U64(*u)), - JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), + JsonValue::Int(i) => Ok(EitherInt::I64(*i)), + JsonValue::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), _ => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } fn lax_int(&'a self) -> ValResult> { match self { - JsonInput::Bool(b) => match *b { + JsonValue::Bool(b) => match *b { true => Ok(EitherInt::I64(1)), false => Ok(EitherInt::I64(0)), }, - JsonInput::Int(i) => Ok(EitherInt::I64(*i)), - JsonInput::Uint(u) => Ok(EitherInt::U64(*u)), - JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), - JsonInput::Float(f) => float_as_int(self, *f), - JsonInput::String(str) => str_as_int(self, str), + JsonValue::Int(i) => Ok(EitherInt::I64(*i)), + JsonValue::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), + JsonValue::Float(f) => float_as_int(self, *f), + JsonValue::Str(str) => str_as_int(self, str), _ => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } fn ultra_strict_float(&'a self) -> ValResult> { match self { - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), + JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } fn strict_float(&'a self) -> ValResult> { match self { - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), - JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)), - JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)), + JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), + JsonValue::Int(i) => Ok(EitherFloat::F64(*i as f64)), _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } fn lax_float(&'a self) 
-> ValResult> { match self { - JsonInput::Bool(b) => match *b { + JsonValue::Bool(b) => match *b { true => Ok(EitherFloat::F64(1.0)), false => Ok(EitherFloat::F64(0.0)), }, - JsonInput::Float(f) => Ok(EitherFloat::F64(*f)), - JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)), - JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)), - JsonInput::String(str) => str_as_float(self, str), + JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), + JsonValue::Int(i) => Ok(EitherFloat::F64(*i as f64)), + JsonValue::Str(str) => str_as_float(self, str), _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { match self { - JsonInput::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py), + JsonValue::Float(f) => create_decimal(PyString::new(py, &f.to_string()), self, py), - JsonInput::String(..) | JsonInput::Int(..) | JsonInput::Uint(..) | JsonInput::BigInt(..) => { + JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) 
=> { create_decimal(self.to_object(py).into_ref(py), self, py) } _ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)), @@ -195,7 +192,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_dict(&'a self, _strict: bool) -> ValResult> { match self { - JsonInput::Object(dict) => Ok(dict.into()), + JsonValue::Object(dict) => Ok(dict.into()), _ => Err(ValError::new(ErrorTypeDefaults::DictType, self)), } } @@ -206,7 +203,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_list(&'a self, _strict: bool) -> ValResult> { match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::ListType, self)), } } @@ -218,7 +215,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_tuple(&'a self, _strict: bool) -> ValResult> { // just as in set's case, List has to be allowed match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::TupleType, self)), } } @@ -230,7 +227,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_set(&'a self, _strict: bool) -> ValResult> { // we allow a list here since otherwise it would be impossible to create a set from JSON match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::SetType, self)), } } @@ -242,7 +239,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_frozenset(&'a self, _strict: bool) -> ValResult> { // we allow a list here since otherwise it would be impossible to create a frozenset from JSON match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), _ => Err(ValError::new(ErrorTypeDefaults::FrozenSetType, self)), } } @@ -253,20 +250,20 @@ impl<'a> Input<'a> for JsonInput { fn extract_generic_iterable(&self) -> ValResult { 
match self { - JsonInput::Array(a) => Ok(GenericIterable::JsonArray(a)), - JsonInput::String(s) => Ok(GenericIterable::JsonString(s)), - JsonInput::Object(object) => Ok(GenericIterable::JsonObject(object)), + JsonValue::Array(a) => Ok(GenericIterable::JsonArray(a)), + JsonValue::Str(s) => Ok(GenericIterable::JsonString(s)), + JsonValue::Object(object) => Ok(GenericIterable::JsonObject(object)), _ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)), } } fn validate_iter(&self) -> ValResult { match self { - JsonInput::Array(a) => Ok(a.clone().into()), - JsonInput::String(s) => Ok(string_to_vec(s).into()), - JsonInput::Object(object) => { + JsonValue::Array(a) => Ok(a.clone().into()), + JsonValue::Str(s) => Ok(string_to_vec(s).into()), + JsonValue::Object(object) => { // return keys iterator to match python's behavior - let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonInput::String(k.clone())).collect()); + let keys: JsonArray = JsonArray::new(object.keys().map(|k| JsonValue::Str(k.clone())).collect()); Ok(keys.into()) } _ => Err(ValError::new(ErrorTypeDefaults::IterableType, self)), @@ -275,7 +272,7 @@ impl<'a> Input<'a> for JsonInput { fn validate_date(&self, _strict: bool) -> ValResult { match self { - JsonInput::String(v) => bytes_as_date(self, v.as_bytes()), + JsonValue::Str(v) => bytes_as_date(self, v.as_bytes()), _ => Err(ValError::new(ErrorTypeDefaults::DateType, self)), } } @@ -291,16 +288,16 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Str(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), _ => Err(ValError::new(ErrorTypeDefaults::TimeType, self)), } } fn lax_time(&self, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior) -> ValResult { match self { - JsonInput::String(v) => bytes_as_time(self, 
v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => int_as_time(self, *v, 0), - JsonInput::Float(v) => float_as_time(self, *v), - JsonInput::BigInt(_) => Err(ValError::new( + JsonValue::Str(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Int(v) => int_as_time(self, *v, 0), + JsonValue::Float(v) => float_as_time(self, *v), + JsonValue::BigInt(_) => Err(ValError::new( ErrorType::TimeParsing { error: Cow::Borrowed( speedate::ParseError::TimeTooLarge @@ -320,7 +317,7 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Str(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), _ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), } } @@ -329,9 +326,9 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => int_as_datetime(self, *v, 0), - JsonInput::Float(v) => float_as_datetime(self, *v), + JsonValue::Str(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Int(v) => int_as_datetime(self, *v, 0), + JsonValue::Float(v) => float_as_datetime(self, *v), _ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), } } @@ -341,7 +338,7 @@ impl<'a> Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Str(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } @@ -350,29 +347,31 @@ impl<'a> 
Input<'a> for JsonInput { microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, ) -> ValResult { match self { - JsonInput::String(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), - JsonInput::Int(v) => Ok(int_as_duration(self, *v)?.into()), - JsonInput::Float(v) => Ok(float_as_duration(self, *v)?.into()), + JsonValue::Str(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), + JsonValue::Int(v) => Ok(int_as_duration(self, *v)?.into()), + JsonValue::Float(v) => Ok(float_as_duration(self, *v)?.into()), _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } } -impl BorrowInput for &'_ JsonInput { - type Input<'a> = JsonInput where Self: 'a; +impl BorrowInput for &'_ JsonValue { + type Input<'a> = JsonValue where Self: 'a; fn borrow_input(&self) -> &Self::Input<'_> { self } } -/// TODO: it would be good to get JsonInput and StringMapping string variants to go through this -/// implementation -/// Required for Dict keys so the string can behave like an Input -impl<'a> Input<'a> for String { +impl AsLocItem for String { fn as_loc_item(&self) -> LocItem { self.to_string().into() } +} +/// TODO: it would be good to get JsonInput and StringMapping string variants to go through this +/// implementation +/// Required for JSON Object keys so the string can behave like an Input +impl<'a> Input<'a> for String { fn as_error_value(&'a self) -> InputValue<'a> { InputValue::String(self) } @@ -398,8 +397,8 @@ impl<'a> Input<'a> for String { )) } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { - serde_json::from_str(self.as_str()).map_err(|e| map_json_err(self, e)) + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + JsonValue::parse(self.as_bytes(), true).map_err(|e| map_json_err(self, e)) } fn strict_str(&'a self) -> ValResult> { @@ -504,3 +503,7 @@ impl BorrowInput for String { self } } + +fn string_to_vec(s: &str) -> JsonArray { + JsonArray::new(s.chars().map(|c| 
JsonValue::Str(c.to_string())).collect()) +} diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 33d7ca296..de59ebce0 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -9,9 +9,11 @@ use pyo3::types::{ #[cfg(not(PyPy))] use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; use pyo3::{intern, PyTypeInfo}; + +use jiter::JsonValue; use speedate::MicrosecondsPrecisionOverflowBehavior; -use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; use crate::validators::decimal::{create_decimal, get_decimal_type}; use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl}; @@ -27,7 +29,7 @@ use super::shared::{ }; use super::{ py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, - GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs, + GenericIterable, GenericIterator, GenericMapping, Input, PyArgs, }; #[cfg(not(PyPy))] @@ -90,7 +92,7 @@ macro_rules! 
extract_dict_items { }; } -impl<'a> Input<'a> for PyAny { +impl AsLocItem for PyAny { fn as_loc_item(&self) -> LocItem { if let Ok(py_str) = self.downcast::() { py_str.to_string_lossy().as_ref().into() @@ -100,7 +102,9 @@ impl<'a> Input<'a> for PyAny { safe_repr(self).to_string().into() } } +} +impl<'a> Input<'a> for PyAny { fn as_error_value(&'a self) -> InputValue<'a> { InputValue::PyAny(self) } @@ -183,19 +187,20 @@ impl<'a> Input<'a> for PyAny { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { - if let Ok(py_bytes) = self.downcast::() { - serde_json::from_slice(py_bytes.as_bytes()).map_err(|e| map_json_err(self, e)) + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + let bytes = if let Ok(py_bytes) = self.downcast::() { + py_bytes.as_bytes() } else if let Ok(py_str) = self.downcast::() { let str = py_string_str(py_str)?; - serde_json::from_str(str).map_err(|e| map_json_err(self, e)) + str.as_bytes() } else if let Ok(py_byte_array) = self.downcast::() { // Safety: from_slice does not run arbitrary Python code and the GIL is held so the - // bytes array will not be mutated while from_slice is reading it - serde_json::from_slice(unsafe { py_byte_array.as_bytes() }).map_err(|e| map_json_err(self, e)) + // bytes array will not be mutated while `JsonValue::parse` is reading it + unsafe { py_byte_array.as_bytes() } } else { - Err(ValError::new(ErrorTypeDefaults::JsonType, self)) - } + return Err(ValError::new(ErrorTypeDefaults::JsonType, self)); + }; + JsonValue::parse(bytes, true).map_err(|e| map_json_err(self, e)) } fn strict_str(&'a self) -> ValResult> { @@ -210,22 +215,6 @@ impl<'a> Input<'a> for PyAny { } } - fn exact_str(&'a self) -> ValResult> { - if let Ok(py_str) = PyString::try_from_exact(self) { - Ok(EitherString::Py(py_str)) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } - } - - fn exact_int(&'a self) -> ValResult> { - if PyInt::is_exact_type_of(self) { - Ok(EitherInt::Py(self)) - } else { - 
Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } - } - fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { if let Ok(py_str) = ::try_from_exact(self) { Ok(py_str.into()) @@ -352,6 +341,22 @@ impl<'a> Input<'a> for PyAny { } } + fn exact_int(&'a self) -> ValResult> { + if PyInt::is_exact_type_of(self) { + Ok(EitherInt::Py(self)) + } else { + Err(ValError::new(ErrorTypeDefaults::IntType, self)) + } + } + + fn exact_str(&'a self) -> ValResult> { + if let Ok(py_str) = PyString::try_from_exact(self) { + Ok(EitherString::Py(py_str)) + } else { + Err(ValError::new(ErrorTypeDefaults::IntType, self)) + } + } + fn ultra_strict_float(&'a self) -> ValResult> { if self.is_instance_of::() { Err(ValError::new(ErrorTypeDefaults::FloatType, self)) diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 72a32d897..b84908edf 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -1,9 +1,10 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; +use jiter::JsonValue; use speedate::MicrosecondsPrecisionOverflowBehavior; -use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::input::py_string_str; use crate::tools::safe_repr; use crate::validators::decimal::create_decimal; @@ -14,7 +15,7 @@ use super::datetime::{ use super::shared::{map_json_err, str_as_bool, str_as_float}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, - GenericIterator, GenericMapping, Input, JsonInput, + GenericIterator, GenericMapping, Input, }; #[derive(Debug)] @@ -52,14 +53,16 @@ impl<'py> StringMapping<'py> { } } -impl<'a> Input<'a> for StringMapping<'a> { +impl AsLocItem for StringMapping<'_> { fn as_loc_item(&self) -> LocItem { match self { Self::String(s) => s.to_string_lossy().as_ref().into(), Self::Mapping(d) => 
safe_repr(d).to_string().into(), } } +} +impl<'a> Input<'a> for StringMapping<'a> { fn as_error_value(&'a self) -> InputValue<'a> { match self { Self::String(s) => s.as_error_value(), @@ -83,11 +86,11 @@ impl<'a> Input<'a> for StringMapping<'a> { } } - fn parse_json(&'a self) -> ValResult<'a, JsonInput> { + fn parse_json(&'a self) -> ValResult<'a, JsonValue> { match self { Self::String(s) => { let str = py_string_str(s)?; - serde_json::from_str(str).map_err(|e| map_json_err(self, e)) + JsonValue::parse(str.as_bytes(), true).map_err(|e| map_json_err(self, e)) } Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), } diff --git a/src/input/mod.rs b/src/input/mod.rs index 22d774a8c..13c835f83 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -7,7 +7,6 @@ mod input_abstract; mod input_json; mod input_python; mod input_string; -mod parse_json; mod return_enums; mod shared; @@ -18,7 +17,6 @@ pub(crate) use datetime::{ }; pub(crate) use input_abstract::{BorrowInput, Input, InputType}; pub(crate) use input_string::StringMapping; -pub(crate) use parse_json::{JsonArray, JsonInput, JsonObject}; pub(crate) use return_enums::{ py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, diff --git a/src/input/parse_json.rs b/src/input/parse_json.rs deleted file mode 100644 index 20a107669..000000000 --- a/src/input/parse_json.rs +++ /dev/null @@ -1,222 +0,0 @@ -use std::fmt; -use std::sync::Arc; - -use num_bigint::BigInt; -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList}; -use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor}; -use smallvec::SmallVec; - -use crate::lazy_index_map::LazyIndexMap; - -/// similar to serde `Value` but with int and float split -#[derive(Clone, Debug)] -pub enum JsonInput { - Null, - Bool(bool), - Int(i64), - 
BigInt(BigInt), - Uint(u64), - Float(f64), - String(String), - Array(JsonArray), - Object(JsonObject), -} -pub type JsonArray = Arc>; -pub type JsonObject = Arc>; - -impl ToPyObject for JsonInput { - fn to_object(&self, py: Python<'_>) -> PyObject { - match self { - Self::Null => py.None(), - Self::Bool(b) => b.into_py(py), - Self::Int(i) => i.into_py(py), - Self::BigInt(b) => b.to_object(py), - Self::Uint(i) => i.into_py(py), - Self::Float(f) => f.into_py(py), - Self::String(s) => s.into_py(py), - Self::Array(v) => PyList::new(py, v.iter().map(|v| v.to_object(py))).into_py(py), - Self::Object(o) => { - let dict = PyDict::new(py); - for (k, v) in o.iter() { - dict.set_item(k, v.to_object(py)).unwrap(); - } - dict.into_py(py) - } - } - } -} - -impl<'de> Deserialize<'de> for JsonInput { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct JsonVisitor; - - impl<'de> Visitor<'de> for JsonVisitor { - type Value = JsonInput; - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("any valid JSON value") - } - - fn visit_bool(self, value: bool) -> Result { - Ok(JsonInput::Bool(value)) - } - - fn visit_i64(self, value: i64) -> Result { - Ok(JsonInput::Int(value)) - } - - fn visit_u64(self, value: u64) -> Result { - match i64::try_from(value) { - Ok(i) => Ok(JsonInput::Int(i)), - Err(_) => Ok(JsonInput::Uint(value)), - } - } - - fn visit_f64(self, value: f64) -> Result { - Ok(JsonInput::Float(value)) - } - - fn visit_str(self, value: &str) -> Result - where - E: SerdeError, - { - Ok(JsonInput::String(value.to_string())) - } - - fn visit_string(self, value: String) -> Result { - Ok(JsonInput::String(value)) - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_none(self) -> Result { - unreachable!() - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_some(self, _: D) -> Result - where - D: 
serde::Deserializer<'de>, - { - unreachable!() - } - - fn visit_unit(self) -> Result { - Ok(JsonInput::Null) - } - - fn visit_seq(self, mut visitor: V) -> Result - where - V: SeqAccess<'de>, - { - let mut vec = SmallVec::new(); - - while let Some(elem) = visitor.next_element()? { - vec.push(elem); - } - - Ok(JsonInput::Array(JsonArray::new(vec))) - } - - fn visit_map(self, mut visitor: V) -> Result - where - V: MapAccess<'de>, - { - const SERDE_JSON_NUMBER: &str = "$serde_json::private::Number"; - match visitor.next_key_seed(KeyDeserializer)? { - Some(first_key) => { - let mut values = LazyIndexMap::new(); - let first_value = visitor.next_value()?; - - // serde_json will parse arbitrary precision numbers into a map - // structure with a "number" key and a String value - 'try_number: { - if first_key == SERDE_JSON_NUMBER { - // Just in case someone tries to actually store that key in a real map, - // keep parsing and continue as a map if so - - if let Some((key, value)) = visitor.next_entry::()? { - // Important to preserve order of the keys - values.insert(first_key, first_value); - values.insert(key, value); - break 'try_number; - } - - if let JsonInput::String(s) = &first_value { - // Normalize the string to either an int or float - let normalized = if s.chars().any(|c| c == '.' || c == 'E' || c == 'e') { - JsonInput::Float( - s.parse() - .map_err(|e| V::Error::custom(format!("expected a float: {e}")))?, - ) - } else if let Ok(i) = s.parse::() { - JsonInput::Int(i) - } else if let Ok(big) = s.parse::() { - JsonInput::BigInt(big) - } else { - // Failed to normalize, just throw it in the map and continue - values.insert(first_key, first_value); - break 'try_number; - }; - - return Ok(normalized); - }; - } else { - values.insert(first_key, first_value); - } - } - - while let Some((key, value)) = visitor.next_entry()? 
{ - values.insert(key, value); - } - Ok(JsonInput::Object(Arc::new(values))) - } - None => Ok(JsonInput::Object(Arc::new(LazyIndexMap::new()))), - } - } - } - - deserializer.deserialize_any(JsonVisitor) - } -} - -struct KeyDeserializer; - -impl<'de> DeserializeSeed<'de> for KeyDeserializer { - type Value = String; - - fn deserialize(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_str(self) - } -} - -impl<'de> Visitor<'de> for KeyDeserializer { - type Value = String; - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string key") - } - - fn visit_str(self, s: &str) -> Result - where - E: serde::de::Error, - { - Ok(s.to_string()) - } - - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn visit_string(self, _: String) -> Result - where - E: serde::de::Error, - { - unreachable!() - } -} diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index daa9f39fe..412842a13 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -4,6 +4,7 @@ use std::ops::Rem; use std::slice::Iter as SliceIter; use std::str::FromStr; +use jiter::{JsonArray, JsonObject, JsonValue}; use num_bigint::BigInt; use pyo3::exceptions::PyTypeError; @@ -26,7 +27,6 @@ use crate::tools::py_err; use crate::validators::{CombinedValidator, ValidationState, Validator}; use super::input_string::StringMapping; -use super::parse_json::{JsonArray, JsonInput, JsonObject}; use super::{py_error_on_minusone, Input}; /// Container for all the collections (sized iterable containers) types, which @@ -50,7 +50,7 @@ pub enum GenericIterable<'a> { PyByteArray(&'a PyByteArray), Sequence(&'a PySequence), Iterator(&'a PyIterator), - JsonArray(&'a [JsonInput]), + JsonArray(&'a [JsonValue]), JsonObject(&'a JsonObject), JsonString(&'a String), } @@ -573,7 +573,7 @@ impl<'py> Iterator for AttributesGenericIterator<'py> { } pub struct 
JsonObjectGenericIterator<'py> { - object_iter: SliceIter<'py, (String, JsonInput)>, + object_iter: SliceIter<'py, (String, JsonValue)>, } impl<'py> JsonObjectGenericIterator<'py> { @@ -585,7 +585,7 @@ impl<'py> JsonObjectGenericIterator<'py> { } impl<'py> Iterator for JsonObjectGenericIterator<'py> { - type Item = ValResult<'py, (&'py String, &'py JsonInput)>; + type Item = ValResult<'py, (&'py String, &'py JsonValue)>; fn next(&mut self) -> Option { self.object_iter.next().map(|(key, value)| Ok((key, value))) @@ -653,7 +653,7 @@ pub struct GenericJsonIterator { } impl GenericJsonIterator { - pub fn next(&mut self, _py: Python) -> PyResult> { + pub fn next(&mut self, _py: Python) -> PyResult> { if self.index < self.array.len() { // panic here is impossible due to bounds check above; compiler should be // able to optimize it away even @@ -667,7 +667,7 @@ impl GenericJsonIterator { } pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> { - InputValue::JsonInput(JsonInput::Array(self.array.clone())) + InputValue::JsonInput(JsonValue::Array(self.array.clone())) } pub fn index(&self) -> usize { @@ -689,12 +689,12 @@ impl<'a> PyArgs<'a> { #[cfg_attr(debug_assertions, derive(Debug))] pub struct JsonArgs<'a> { - pub args: Option<&'a [JsonInput]>, + pub args: Option<&'a [JsonValue]>, pub kwargs: Option<&'a JsonObject>, } impl<'a> JsonArgs<'a> { - pub fn new(args: Option<&'a [JsonInput]>, kwargs: Option<&'a JsonObject>) -> Self { + pub fn new(args: Option<&'a [JsonValue]>, kwargs: Option<&'a JsonObject>) -> Self { Self { args, kwargs } } } diff --git a/src/input/shared.rs b/src/input/shared.rs index 105da4bcc..718210098 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -1,12 +1,12 @@ -use num_bigint::BigInt; use pyo3::sync::GILOnceCell; use pyo3::{intern, Py, PyAny, Python, ToPyObject}; +use jiter::JsonValueError; +use num_bigint::BigInt; + use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; -use 
super::parse_json::{JsonArray, JsonInput}; use super::{EitherFloat, EitherInt, Input}; - static ENUM_META_OBJECT: GILOnceCell> = GILOnceCell::new(); pub fn get_enum_meta_object(py: Python) -> Py { @@ -20,7 +20,7 @@ pub fn get_enum_meta_object(py: Python) -> Py { .clone() } -pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> { +pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError<'a> { ValError::new( ErrorType::JsonInvalid { error: error.to_string(), @@ -164,7 +164,3 @@ pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a Py } Ok(EitherInt::Py(numerator)) } - -pub fn string_to_vec(s: &str) -> JsonArray { - JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect()) -} diff --git a/src/lazy_index_map.rs b/src/lazy_index_map.rs deleted file mode 100644 index c5621f877..000000000 --- a/src/lazy_index_map.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::borrow::Borrow; -use std::cmp::{Eq, PartialEq}; -use std::fmt::Debug; -use std::hash::Hash; -use std::slice::Iter as SliceIter; -use std::sync::OnceLock; - -use ahash::AHashMap; -use smallvec::SmallVec; - -#[derive(Debug, Clone, Default)] -pub struct LazyIndexMap { - vec: SmallVec<[(K, V); 8]>, - map: OnceLock>, -} - -/// Like [IndexMap](https://docs.rs/indexmap/latest/indexmap/) but only builds the lookup map when it's needed. 
-impl LazyIndexMap -where - K: Clone + Debug + Eq + Hash, - V: Debug, -{ - pub fn new() -> Self { - Self { - vec: SmallVec::new(), - map: OnceLock::new(), - } - } - - pub fn insert(&mut self, key: K, value: V) { - if let Some(map) = self.map.get_mut() { - map.insert(key.clone(), self.vec.len()); - } - self.vec.push((key, value)); - } - - pub fn len(&self) -> usize { - self.vec.len() - } - - pub fn get(&self, key: &Q) -> Option<&V> - where - K: Borrow + PartialEq, - Q: Hash + Eq, - { - let map = self.map.get_or_init(|| { - self.vec - .iter() - .enumerate() - .map(|(index, (key, _))| (key.clone(), index)) - .collect() - }); - map.get(key).map(|&i| &self.vec[i].1) - } - - pub fn keys(&self) -> impl Iterator { - self.vec.iter().map(|(k, _)| k) - } - - pub fn iter(&self) -> SliceIter<'_, (K, V)> { - self.vec.iter() - } -} diff --git a/src/lib.rs b/src/lib.rs index b241cdb8a..f969c0657 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,8 @@ extern crate core; use std::sync::OnceLock; +use pyo3::exceptions::PyTypeError; +use pyo3::types::{PyByteArray, PyBytes, PyString}; use pyo3::{prelude::*, sync::GILOnceCell}; // parse this first to get access to the contained macro @@ -15,7 +17,6 @@ mod build_tools; mod definitions; mod errors; mod input; -mod lazy_index_map; mod lookup_key; mod recursion_guard; mod serializers; @@ -36,6 +37,19 @@ pub use serializers::{ }; pub use validators::{validate_core_schema, PySome, SchemaValidator}; +#[pyfunction(signature = (data, *, allow_inf_nan=true))] +pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool) -> PyResult { + if let Ok(py_bytes) = data.downcast::() { + jiter::python_parse(py, py_bytes.as_bytes(), allow_inf_nan) + } else if let Ok(py_str) = data.downcast::() { + jiter::python_parse(py, py_str.to_str()?.as_bytes(), allow_inf_nan) + } else if let Ok(py_byte_array) = data.downcast::() { + jiter::python_parse(py, &py_byte_array.to_vec(), allow_inf_nan) + } else { + Err(PyTypeError::new_err("Expected bytes, bytearray or 
str")) + } +} + pub fn get_pydantic_core_version() -> &'static str { static PYDANTIC_CORE_VERSION: OnceLock = OnceLock::new(); @@ -95,6 +109,7 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(to_json, m)?)?; + m.add_function(wrap_pyfunction!(from_json, m)?)?; m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?; m.add_function(wrap_pyfunction!(list_all_errors, m)?)?; m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?; diff --git a/src/lookup_key.rs b/src/lookup_key.rs index bb7d7e3d7..f833c00af 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -5,9 +5,11 @@ use pyo3::exceptions::{PyAttributeError, PyTypeError}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyMapping, PyString}; +use jiter::{JsonObject, JsonValue}; + use crate::build_tools::py_schema_err; use crate::errors::{py_err_string, ErrorType, ValError, ValLineError, ValResult}; -use crate::input::{Input, JsonInput, JsonObject, StringMapping}; +use crate::input::{Input, StringMapping}; use crate::tools::{extract_i64, py_err}; /// Used for getting items from python dicts, python objects, or JSON objects, in different ways @@ -264,7 +266,7 @@ impl LookupKey { pub fn json_get<'data, 's>( &'s self, dict: &'data JsonObject, - ) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonInput)>> { + ) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonValue)>> { match self { Self::Simple { key, path, .. 
} => match dict.get(key) { Some(value) => Ok(Some((path, value))), @@ -289,13 +291,13 @@ impl LookupKey { // first step is different from the rest as we already know dict is JsonObject // because of above checks, we know that path should have at least one element, hence unwrap - let v: &JsonInput = match path_iter.next().unwrap().json_obj_get(dict) { + let v: &JsonValue = match path_iter.next().unwrap().json_obj_get(dict) { Some(v) => v, None => continue, }; // similar to above - // iterate over the path and plug each value into the JsonInput from the last step, starting with v + // iterate over the path and plug each value into the JsonValue from the last step, starting with v // from the first step, this could just be a loop but should be somewhat faster with a functional design if let Some(v) = path_iter.try_fold(v, |d, loc| loc.json_get(d)) { // Successfully found an item, return it @@ -481,10 +483,10 @@ impl PathItem { } } - pub fn json_get<'a>(&self, any_json: &'a JsonInput) -> Option<&'a JsonInput> { + pub fn json_get<'a>(&self, any_json: &'a JsonValue) -> Option<&'a JsonValue> { match any_json { - JsonInput::Object(v_obj) => self.json_obj_get(v_obj), - JsonInput::Array(v_array) => match self { + JsonValue::Object(v_obj) => self.json_obj_get(v_obj), + JsonValue::Array(v_array) => match self { Self::Pos(index) => v_array.get(*index), Self::Neg(index) => { if let Some(index) = v_array.len().checked_sub(*index) { @@ -499,7 +501,7 @@ impl PathItem { } } - pub fn json_obj_get<'a>(&self, json_obj: &'a JsonObject) -> Option<&'a JsonInput> { + pub fn json_obj_get<'a>(&self, json_obj: &'a JsonObject) -> Option<&'a JsonValue> { match self { Self::S(key, _) => json_obj.get(key), _ => None, diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 7f406ba16..748b13338 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -6,7 +6,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use 
crate::build_tools::schema_or_config_same; -use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index b18faea2c..d93441ce0 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -7,7 +7,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{BorrowInput, GenericArguments, Input}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 5026afba3..c7df345ed 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -3,7 +3,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::is_strict; -use crate::errors::{ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ValError, ValLineError, ValResult}; use crate::input::BorrowInput; use crate::input::{ DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, diff --git a/src/validators/function.rs b/src/validators/function.rs index adb143696..66bbafbb9 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -1,16 +1,16 @@ use std::sync::Arc; -use pyo3::exceptions::{PyAssertionError, PyTypeError, PyValueError}; +use pyo3::exceptions::{PyAssertionError, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyAny, PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::errors::{ - ErrorType, LocItem, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, 
ValResult, ValidationError, + AsLocItem, ErrorType, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError, }; use crate::input::Input; use crate::py_gc::PyGcTraverse; -use crate::tools::{function_name, py_err, safe_repr, SchemaDict}; +use crate::tools::{function_name, safe_repr, SchemaDict}; use crate::PydanticUseDefault; use super::generator::InternalValidator; @@ -406,13 +406,7 @@ struct ValidatorCallable { #[pymethods] impl ValidatorCallable { fn __call__(&mut self, py: Python, input_value: &PyAny, outer_location: Option<&PyAny>) -> PyResult { - let outer_location = match outer_location { - Some(ol) => match LocItem::try_from(ol) { - Ok(ol) => Some(ol), - Err(_) => return py_err!(PyTypeError; "outer_location must be a str or int"), - }, - None => None, - }; + let outer_location = outer_location.map(AsLocItem::as_loc_item); self.validator.validate(py, input_value, outer_location) } @@ -440,13 +434,7 @@ struct AssignmentValidatorCallable { #[pymethods] impl AssignmentValidatorCallable { fn __call__(&mut self, py: Python, input_value: &PyAny, outer_location: Option<&PyAny>) -> PyResult { - let outer_location = match outer_location { - Some(ol) => match LocItem::try_from(ol) { - Ok(ol) => Some(ol), - Err(_) => return py_err!(PyTypeError; "outer_location must be a str or int"), - }, - None => None, - }; + let outer_location = outer_location.map(AsLocItem::as_loc_item); self.validator.validate_assignment( py, input_value, diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index b79145c97..17ec81670 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -7,7 +7,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, 
ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, StringMappingGenericIterator, diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index dab492da7..5839959e7 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -8,7 +8,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config, schema_or_config_same, ExtraBehavior}; -use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, StringMappingGenericIterator, diff --git a/src/validators/union.rs b/src/validators/union.rs index a8bd29d7d..837114408 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -9,7 +9,7 @@ use smallvec::SmallVec; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config}; -use crate::errors::{ErrorType, LocItem, ValError, ValLineError, ValResult}; +use crate::errors::{AsLocItem, ErrorType, ValError, ValLineError, ValResult}; use crate::input::{GenericMapping, Input}; use crate::lookup_key::LookupKey; use crate::py_gc::PyGcTraverse; @@ -566,7 +566,7 @@ impl TaggedUnionValidator { if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) { return match validator.validate(py, input, state) { Ok(res) => Ok(res), - Err(err) => Err(err.with_outer_location(LocItem::try_from(tag.to_object(py).into_ref(py))?)), + Err(err) => Err(err.with_outer_location(tag.as_loc_item())), }; } match self.custom_error { diff --git a/tests/test_json.py b/tests/test_json.py index 9bba05c14..4ef8a1d40 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -4,7 +4,7 @@ from typing import 
List import pytest -from dirty_equals import IsList +from dirty_equals import IsFloatNan, IsList import pydantic_core from pydantic_core import ( @@ -358,3 +358,10 @@ def test_bad_repr(): to_json(b) assert to_json(b, serialize_unknown=True) == b'""' + + +def test_inf_nan_allow(): + v = SchemaValidator(core_schema.float_schema(allow_inf_nan=True)) + assert v.validate_json('Infinity') == float('inf') + assert v.validate_json('-Infinity') == float('-inf') + assert v.validate_json('NaN') == IsFloatNan() diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index 43b3d19b9..b9fabeaed 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -140,7 +140,8 @@ def test_decimal_strict_json(input_value, expected): {'ge': 0}, -0.1, Err( - 'Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-0.1, input_type=float]' + 'Input should be greater than or equal to 0 ' + '[type=greater_than_equal, input_value=-0.1, input_type=float]' ), ), ({'gt': 0}, 0.1, Decimal('0.1')), @@ -150,14 +151,14 @@ def test_decimal_strict_json(input_value, expected): ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), ({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')), ({'gt': 0, 'allow_inf_nan': True}, float('inf'), Decimal('inf')), + ({'allow_inf_nan': True}, float('-inf'), Decimal('-inf')), + ({'allow_inf_nan': True}, float('nan'), FunctionCheck(math.isnan)), ({'lt': 0}, 0, Err('Input should be less than 0')), ({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')), ], ) def test_decimal_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): v = py_and_json({'type': 'decimal', **kwargs}) - if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value): - expected = Err('Invalid JSON') if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): 
v.validate_test(input_value) diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index b18181fbb..35b04c3f9 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -4,9 +4,9 @@ from typing import Any, Dict import pytest -from dirty_equals import FunctionCheck, IsStr +from dirty_equals import FunctionCheck, IsFloatNan, IsStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson, plain_repr @@ -92,8 +92,6 @@ def test_float_strict(py_and_json: PyAndJson, input_value, expected): ) def test_float_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): v = py_and_json({'type': 'float', **kwargs}) - if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value): - expected = Err('Invalid JSON') if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_test(input_value) @@ -376,3 +374,34 @@ def test_string_with_underscores() -> None: v.validate_python(edge_case) with pytest.raises(ValidationError): v.validate_json(f'"{edge_case}"') + + +def test_allow_inf_nan_true_json() -> None: + v = SchemaValidator(core_schema.float_schema()) + + assert v.validate_json('123') == 123 + assert v.validate_json('NaN') == IsFloatNan() + assert v.validate_json('Infinity') == float('inf') + assert v.validate_json('-Infinity') == float('-inf') + + +def test_allow_inf_nan_false_json() -> None: + v = SchemaValidator(core_schema.float_schema(), core_schema.CoreConfig(allow_inf_nan=False)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError) as exc_info1: + v.validate_json('NaN') + # insert_assert(exc_info.value.errors()) + assert exc_info1.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': IsFloatNan()} + ] + with 
pytest.raises(ValidationError) as exc_info2: + v.validate_json('Infinity') + assert exc_info2.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': float('inf')} + ] + with pytest.raises(ValidationError) as exc_info3: + v.validate_json('-Infinity') + assert exc_info3.value.errors(include_url=False) == [ + {'type': 'finite_number', 'loc': (), 'msg': 'Input should be a finite number', 'input': float('-inf')} + ] diff --git a/tests/validators/test_function.py b/tests/validators/test_function.py index 9f94ceb1b..e5ccba1e3 100644 --- a/tests/validators/test_function.py +++ b/tests/validators/test_function.py @@ -289,8 +289,19 @@ def f(input_value, validator, info): v = SchemaValidator(core_schema.with_info_wrap_validator_function(f, core_schema.int_schema())) - with pytest.raises(TypeError, match='^outer_location must be a str or int$'): - v.validate_python(4) + assert v.validate_python(4) == 6 + + with pytest.raises(ValidationError) as exc_info: + v.validate_python('wrong') + # insert_assert(exc_info.value.errors(include_url=False)) + assert exc_info.value.errors(include_url=False) == [ + { + 'type': 'int_parsing', + 'loc': ("('4',)",), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'wrong', + } + ] def test_function_after(): diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index dedc2bd93..61acab7fb 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -6,7 +6,7 @@ import pytest from dirty_equals import IsStr -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson, plain_repr @@ -472,3 +472,27 @@ class PlainEnum(Enum): v_lax = v.validate_python(PlainEnum.ONE) assert v_lax == 1 assert type(v_lax) == int + + +def test_allow_inf_nan_true_json() -> None: + v = 
SchemaValidator(core_schema.int_schema(), core_schema.CoreConfig(allow_inf_nan=True)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('NaN') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('Infinity') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('-Infinity') + + +def test_allow_inf_nan_false_json() -> None: + v = SchemaValidator(core_schema.int_schema(), core_schema.CoreConfig(allow_inf_nan=False)) + + assert v.validate_json('123') == 123 + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('NaN') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('Infinity') + with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): + v.validate_json('-Infinity') From 3eeebf37336f189843c472ee1c1b89564b7b2777 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 12:06:54 +0000 Subject: [PATCH 101/550] Bump black from 23.10.0 to 23.10.1 (#1057) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 79a41f7ea..655ff8bda 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ -black==23.10.0 +black==23.10.1 griffe==0.36.9 pyright==1.1.334 ruff==0.1.3 From 828eb33fc58d65456dd192e29d0bd4d8631419fd Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 
12:11:55 +0000 Subject: [PATCH 102/550] Bump serde_json from 1.0.107 to 1.0.108 (#1058) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f16b93697..5af1ec727 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -503,9 +503,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.107" +version = "1.0.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" dependencies = [ "indexmap", "itoa", diff --git a/Cargo.toml b/Cargo.toml index 211fdfe01..bbeaabeac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] } regex = "1.10.2" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" -serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]} +serde_json = {version = "1.0.108", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.190", features = ["derive"] } speedate = "0.12.0" From d43d0876749df74f9c341bdc91b4c1e47b69cb38 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 6 Nov 2023 13:20:14 +0000 Subject: [PATCH 103/550] Uprev to 2.12.0 (#1061) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5af1ec727..4768d42d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,7 +321,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.11.0" +version = "2.12.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index bbeaabeac..d5b4ca605 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] 
name = "pydantic-core" -version = "2.11.0" +version = "2.12.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 1cf1c75fac74b77a730d183dec9bc4a6a700fde4 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 6 Nov 2023 13:20:36 +0000 Subject: [PATCH 104/550] uprev speedate, prevent - sign as datetime (#1060) --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- tests/validators/test_date.py | 2 ++ tests/validators/test_datetime.py | 2 ++ 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4768d42d5..c6deb4246 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -521,9 +521,9 @@ checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" [[package]] name = "speedate" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c028e117e67c1f3224f5f834b3e48d4133dc11ec509aa19fdfa6c0987efed332" +checksum = "242f76c50fd18cbf098607090ade73a08d39cfd84ea835f3796a2c855223b19b" dependencies = [ "strum", "strum_macros", diff --git a/Cargo.toml b/Cargo.toml index d5b4ca605..242479712 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ strum_macros = "0.25.3" serde_json = {version = "1.0.108", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.190", features = ["derive"] } -speedate = "0.12.0" +speedate = "0.13.0" smallvec = "1.11.1" ahash = "0.8.6" url = "2.4.1" diff --git a/tests/validators/test_date.py b/tests/validators/test_date.py index 5ddde4884..6a552a57b 100644 --- a/tests/validators/test_date.py +++ b/tests/validators/test_date.py @@ -64,6 +64,8 @@ ), id='-inf', ), + pytest.param('-', Err('Input should be a valid date or datetime, input is too short'), id='minus'), + pytest.param('+', Err('Input should be a valid date or datetime, input is too short'), id='pus'), ], ) def test_date(input_value, expected): diff --git a/tests/validators/test_datetime.py 
b/tests/validators/test_datetime.py index 67581119b..89e9c1c53 100644 --- a/tests/validators/test_datetime.py +++ b/tests/validators/test_datetime.py @@ -36,6 +36,8 @@ (float('nan'), Err('Input should be a valid datetime, NaN values not permitted [type=datetime_parsing,')), (float('inf'), Err('Input should be a valid datetime, dates after 9999')), (float('-inf'), Err('Input should be a valid datetime, dates before 1600')), + ('-', Err('Input should be a valid datetime, input is too short [type=datetime_parsing,')), + ('+', Err('Input should be a valid datetime, input is too short [type=datetime_parsing,')), ], ) def test_datetime(input_value, expected): From cb47b966923e8a4c5b9bba9b6f3ccb1818bf2231 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 6 Nov 2023 13:27:48 +0000 Subject: [PATCH 105/550] check not type in serialization (#962) --- src/serializers/ob_type.rs | 2 ++ tests/serializers/test_any.py | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index 109aed3bb..ff43a1065 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -333,6 +333,7 @@ fn is_dataclass(op_value: Option<&PyAny>) -> bool { value .hasattr(intern!(value.py(), "__dataclass_fields__")) .unwrap_or(false) + && !value.is_instance_of::() } else { false } @@ -343,6 +344,7 @@ fn is_pydantic_serializable(op_value: Option<&PyAny>) -> bool { value .hasattr(intern!(value.py(), "__pydantic_serializer__")) .unwrap_or(false) + && !value.is_instance_of::() } else { false } diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 448ca108a..98ec22c1f 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -505,6 +505,14 @@ class Foo: assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict() assert s.to_json(Foo(a='hello', b=b'more'), exclude={'a'}) == b'{}' + assert s.to_python(Foo) == Foo + with pytest.raises(PydanticSerializationError, 
match=r"Unable to serialize unknown type: "): + s.to_python(Foo, mode='json') + with pytest.raises(PydanticSerializationError, match=r"Unable to serialize unknown type: "): + s.to_json(Foo) + assert s.to_python(Foo, mode='json', fallback=lambda x: x.__name__) == 'Foo' + assert s.to_json(Foo, fallback=lambda x: x.__name__) == b'"Foo"' + def test_dataclass_classvar(any_serializer): @dataclasses.dataclass From 5de6b75b2fbdf6d3cb48e9f59909e07326698f69 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 6 Nov 2023 18:04:59 +0000 Subject: [PATCH 106/550] reduce dependabot frequency (#1059) --- .github/dependabot.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index e9157d2f2..b93ab648d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -3,14 +3,14 @@ updates: - package-ecosystem: "cargo" directory: "/" schedule: - interval: "weekly" + interval: "monthly" - package-ecosystem: "pip" directory: "/" schedule: - interval: "weekly" + interval: "monthly" - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "weekly" + interval: "monthly" From d80c454bbfc30a9aa3ee6d3315a339caf3ef27f6 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 6 Nov 2023 18:24:51 +0000 Subject: [PATCH 107/550] fix: nan inf float (#1062) Co-authored-by: JeanArhancet --- python/pydantic_core/core_schema.py | 3 + src/errors/validation_exception.rs | 4 +- src/serializers/config.rs | 29 + src/serializers/errors.rs | 27 +- src/serializers/mod.rs | 1 + src/serializers/ser.rs | 1299 ++++++++++++++++++++ src/serializers/shared.rs | 7 +- src/serializers/type_serializers/float.rs | 102 ++ src/serializers/type_serializers/mod.rs | 1 + src/serializers/type_serializers/simple.rs | 1 - tests/serializers/test_simple.py | 27 + tests/test_errors.py | 2 +- tests/validators/test_float.py | 1 + 13 files changed, 1496 insertions(+), 8 deletions(-) create mode 100644 src/serializers/ser.rs create mode 100644 
src/serializers/type_serializers/float.rs diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 40d1eec77..fec3b9966 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -68,6 +68,8 @@ class CoreConfig(TypedDict, total=False): allow_inf_nan: Whether to allow infinity and NaN values for float fields. Default is `True`. ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'. ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'. + ser_json_inf_nan: The serialization option for infinity and NaN values + in float fields. Default is 'null'. hide_input_in_errors: Whether to hide input data from `ValidationError` representation. validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError. Requires exceptiongroup backport pre Python 3.11. @@ -102,6 +104,7 @@ class CoreConfig(TypedDict, total=False): # the config options are used to customise serialization to JSON ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601' ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8' + ser_json_inf_nan: Literal['null', 'constants'] # default: 'null' # used to hide input data from ValidationError repr hide_input_in_errors: bool validation_error_cause: bool # default: False diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index cb3802d56..d616e3022 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -324,12 +324,12 @@ impl ValidationError { Some(indent) => { let indent = vec![b' '; indent]; let formatter = PrettyFormatter::with_indent(&indent); - let mut ser = serde_json::Serializer::with_formatter(writer, formatter); + let mut ser = crate::serializers::ser::PythonSerializer::with_formatter(writer, formatter); serializer.serialize(&mut ser).map_err(json_py_err)?; ser.into_inner() } None => { - let mut ser = 
serde_json::Serializer::new(writer); + let mut ser = crate::serializers::ser::PythonSerializer::new(writer); serializer.serialize(&mut ser).map_err(json_py_err)?; ser.into_inner() } diff --git a/src/serializers/config.rs b/src/serializers/config.rs index fb65623b7..e83497f64 100644 --- a/src/serializers/config.rs +++ b/src/serializers/config.rs @@ -189,3 +189,32 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr { Err(err) => err, } } + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub(crate) enum InfNanMode { + #[default] + Null, + Constants, +} + +impl FromStr for InfNanMode { + type Err = PyErr; + + fn from_str(s: &str) -> Result { + match s { + "null" => Ok(Self::Null), + "constants" => Ok(Self::Constants), + s => py_schema_err!( + "Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`", + s + ), + } + } +} + +impl FromPyObject<'_> for InfNanMode { + fn extract(ob: &'_ PyAny) -> PyResult { + let s = ob.extract::<&str>()?; + Self::from_str(s) + } +} diff --git a/src/serializers/errors.rs b/src/serializers/errors.rs index ac4ea784f..71a0a024e 100644 --- a/src/serializers/errors.rs +++ b/src/serializers/errors.rs @@ -14,8 +14,33 @@ pub(super) fn py_err_se_err(py_error: E) -> T { T::custom(py_error.to_string()) } +#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")] +#[derive(Debug, Clone)] +pub struct PythonSerializerError { + pub message: String, +} + +impl fmt::Display for PythonSerializerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for PythonSerializerError {} + +impl serde::ser::Error for PythonSerializerError { + fn custom(msg: T) -> Self + where + T: fmt::Display, + { + PythonSerializerError { + message: format!("{msg}"), + } + } +} + /// convert a serde serialization error into a `PyErr` -pub(super) fn se_err_py_err(error: serde_json::Error) -> PyErr { +pub(super) fn se_err_py_err(error: 
PythonSerializerError) -> PyErr { let s = error.to_string(); if let Some(msg) = s.strip_prefix(UNEXPECTED_TYPE_SER_MARKER) { if msg.is_empty() { diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 00d70162f..e9208a510 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -23,6 +23,7 @@ mod fields; mod filter; mod infer; mod ob_type; +pub mod ser; mod shared; mod type_serializers; diff --git a/src/serializers/ser.rs b/src/serializers/ser.rs new file mode 100644 index 000000000..170cd1849 --- /dev/null +++ b/src/serializers/ser.rs @@ -0,0 +1,1299 @@ +use std::{io, num::FpCategory}; + +use serde::{ser::Impossible, serde_if_integer128, Serialize, Serializer}; +use serde_json::ser::{CompactFormatter, Formatter, PrettyFormatter, State}; + +use super::errors::PythonSerializerError; + +macro_rules! tri { + ($e:expr $(,)?) => { + match $e { + core::result::Result::Ok(val) => val, + core::result::Result::Err(err) => return core::result::Result::Err(err), + } + }; +} + +type Result = std::result::Result; +const TOKEN: &str = "$serde_json::private::Number"; +pub struct PythonSerializer { + writer: W, + formatter: F, +} + +impl PythonSerializer +where + W: io::Write, +{ + /// Creates a new JSON serializer. + #[inline] + pub fn new(writer: W) -> Self { + PythonSerializer::with_formatter(writer, CompactFormatter) + } +} + +impl<'a, W> PythonSerializer> +where + W: io::Write, +{ + /// Creates a new JSON pretty print serializer. + #[inline] + pub fn pretty(writer: W) -> Self { + PythonSerializer::with_formatter(writer, PrettyFormatter::new()) + } +} + +impl PythonSerializer +where + W: io::Write, + F: Formatter, +{ + /// Creates a new JSON visitor whose output will be written to the writer + /// specified. + #[inline] + pub fn with_formatter(writer: W, formatter: F) -> Self { + PythonSerializer { writer, formatter } + } + + /// Unwrap the `Writer` from the `Serializer`. 
+ #[inline] + pub fn into_inner(self) -> W { + self.writer + } +} + +impl<'a, W, F> Serializer for &'a mut PythonSerializer +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + type SerializeSeq = Compound<'a, W, F>; + type SerializeTuple = Compound<'a, W, F>; + type SerializeTupleStruct = Compound<'a, W, F>; + type SerializeTupleVariant = Compound<'a, W, F>; + type SerializeMap = Compound<'a, W, F>; + type SerializeStruct = Compound<'a, W, F>; + type SerializeStructVariant = Compound<'a, W, F>; + + #[inline] + fn serialize_bool(self, value: bool) -> Result<()> { + self.formatter + .write_bool(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + #[inline] + fn serialize_i8(self, value: i8) -> Result<()> { + self.formatter + .write_i8(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_i16(self, value: i16) -> Result { + self.formatter + .write_i16(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_i32(self, value: i32) -> Result { + self.formatter + .write_i32(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_i64(self, value: i64) -> Result { + self.formatter + .write_i64(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u8(self, value: u8) -> Result { + self.formatter + .write_u8(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u16(self, value: u16) -> Result { + self.formatter + .write_u16(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u32(self, value: u32) -> Result { + self.formatter + .write_u32(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn 
serialize_u64(self, value: u64) -> Result { + self.formatter + .write_u64(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_u128(self, value: u128) -> Result<()> { + self.formatter + .write_u128(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + #[inline] + fn serialize_f32(self, value: f32) -> Result<()> { + match value.classify() { + FpCategory::Nan => self + .formatter + .write_number_str(&mut self.writer, "NaN") + .map_err(|e| PythonSerializerError { message: e.to_string() }), + FpCategory::Infinite => { + let infinity = if value.is_sign_negative() { + "-Infinity" + } else { + "Infinity" + }; + self.formatter + .write_number_str(&mut self.writer, infinity) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + _ => self + .formatter + .write_f32(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }), + } + } + + fn serialize_f64(self, value: f64) -> Result { + match value.classify() { + FpCategory::Nan => self + .formatter + .write_number_str(&mut self.writer, "NaN") + .map_err(|e| PythonSerializerError { message: e.to_string() }), + FpCategory::Infinite => { + let infinity = if value.is_sign_negative() { + "-Infinity" + } else { + "Infinity" + }; + self.formatter + .write_number_str(&mut self.writer, infinity) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + _ => self + .formatter + .write_f64(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }), + } + } + + fn serialize_char(self, value: char) -> Result { + // A char encoded as UTF-8 takes 4 bytes at most. 
+ let mut buf = [0; 4]; + self.serialize_str(value.encode_utf8(&mut buf)) + } + + fn serialize_str(self, value: &str) -> Result { + format_escaped_str(&mut self.writer, &mut self.formatter, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_bytes(self, value: &[u8]) -> Result<()> { + self.formatter + .write_byte_array(&mut self.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_none(self) -> Result { + self.formatter + .write_null(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_some(self, value: &T) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_unit(self) -> Result { + self.formatter + .write_null(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result + where + T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result + where + T: Serialize, + { + tri!(self + .formatter + .begin_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_key(&mut self.writer, true) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self.serialize_str(variant)); + tri!(self + .formatter + .end_object_key(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_value(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() 
})); + tri!(value.serialize(&mut *self)); + tri!(self + .formatter + .end_object_value(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + self.formatter + .end_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_seq(self, len: Option) -> Result { + tri!(self + .formatter + .begin_array(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + if len == Some(0) { + tri!(self + .formatter + .end_array(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(Compound::Map { + ser: self, + state: State::Empty, + }) + } else { + Ok(Compound::Map { + ser: self, + state: State::First, + }) + } + } + + fn serialize_tuple(self, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_struct(self, _name: &'static str, len: usize) -> Result { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + tri!(self + .formatter + .begin_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_key(&mut self.writer, true) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self.serialize_str(variant)); + tri!(self + .formatter + .end_object_key(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_value(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + self.serialize_seq(Some(len)) + } + + fn serialize_map(self, len: Option) -> Result { + tri!(self + .formatter + .begin_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + if len == Some(0) { + tri!(self + .formatter + .end_object(&mut self.writer) + .map_err(|e| PythonSerializerError { 
message: e.to_string() })); + Ok(Compound::Map { + ser: self, + state: State::Empty, + }) + } else { + Ok(Compound::Map { + ser: self, + state: State::First, + }) + } + } + + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + match name { + TOKEN => Ok(Compound::Number { ser: self }), + _ => self.serialize_map(Some(len)), + } + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + tri!(self + .formatter + .begin_object(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_key(&mut self.writer, true) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self.serialize_str(variant)); + tri!(self + .formatter + .end_object_key(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .formatter + .begin_object_value(&mut self.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + self.serialize_map(Some(len)) + } +} + +impl<'a, W, F> serde::ser::SerializeSeq for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match self { + Compound::Map { ser, state } => { + tri!(ser + .formatter + .begin_array_value(&mut ser.writer, *state == State::First) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + *state = State::Rest; + tri!(value.serialize(&mut **ser)); + tri!(ser + .formatter + .end_array_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. 
} => unreachable!(), + } + } + + fn end(self) -> Result<()> { + match self { + Compound::Map { ser, state } => { + match state { + State::Empty => {} + _ => tri!(ser + .formatter + .end_array(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })), + } + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } +} + +impl<'a, W, F> serde::ser::SerializeTuple for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result<()> { + serde::ser::SerializeSeq::end(self) + } +} + +impl<'a, W, F> serde::ser::SerializeTupleStruct for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result<()> { + serde::ser::SerializeSeq::end(self) + } +} + +impl<'a, W, F> serde::ser::SerializeTupleVariant for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + #[inline] + fn end(self) -> Result<()> { + match self { + Compound::Map { ser, state } => { + match state { + State::Empty => {} + _ => tri!(ser + .formatter + .end_array(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })), + } + tri!(ser + .formatter + .end_object_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(ser + .formatter + .end_object(&mut ser.writer) + 
.map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } +} + +impl<'a, W, F> serde::ser::SerializeMap for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match self { + Compound::Map { ser, state } => { + tri!(ser + .formatter + .begin_object_key(&mut ser.writer, *state == State::First) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + *state = State::Rest; + + tri!(key.serialize(MapKeySerializer { ser: *ser })); + + tri!(ser + .formatter + .end_object_key(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } + + #[inline] + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match self { + Compound::Map { ser, .. } => { + tri!(ser + .formatter + .begin_object_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(value.serialize(&mut **ser)); + tri!(ser + .formatter + .end_object_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. } => unreachable!(), + } + } + + #[inline] + fn end(self) -> Result<()> { + match self { + Compound::Map { ser, state } => { + match state { + State::Empty => {} + _ => tri!(ser + .formatter + .end_object(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })), + } + Ok(()) + } + Compound::Number { .. 
} => unreachable!(), + } + } +} + +impl<'a, W, F> serde::ser::SerializeStruct for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match self { + Compound::Map { .. } => serde::ser::SerializeMap::serialize_entry(self, key, value), + Compound::Number { ser, .. } => { + if key == TOKEN { + tri!(value.serialize(NumberStrEmitter(ser))); + Ok(()) + } else { + Err(invalid_number()) + } + } + } + } + + #[inline] + fn end(self) -> Result<()> { + match self { + Compound::Map { .. } => serde::ser::SerializeMap::end(self), + Compound::Number { .. } => Ok(()), + } + } +} + +impl<'a, W, F> serde::ser::SerializeStructVariant for Compound<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + match *self { + Compound::Map { .. } => serde::ser::SerializeStruct::serialize_field(self, key, value), + Compound::Number { .. } => unreachable!(), + } + } + + #[inline] + fn end(self) -> Result<()> { + match self { + Compound::Map { ser, state } => { + match state { + State::Empty => {} + _ => tri!(ser + .formatter + .end_object(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })), + } + tri!(ser + .formatter + .end_object_value(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(ser + .formatter + .end_object(&mut ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + Compound::Number { .. 
} => unreachable!(), + } + } +} + +fn format_escaped_str(writer: &mut W, formatter: &mut F, value: &str) -> io::Result<()> +where + W: ?Sized + io::Write, + F: ?Sized + Formatter, +{ + tri!(formatter.begin_string(writer)); + tri!(format_escaped_str_contents(writer, formatter, value)); + formatter.end_string(writer) +} + +fn format_escaped_str_contents(writer: &mut W, formatter: &mut F, value: &str) -> io::Result<()> +where + W: ?Sized + io::Write, + F: ?Sized + Formatter, +{ + let bytes = value.as_bytes(); + + let mut start = 0; + + for (i, &byte) in bytes.iter().enumerate() { + let escape = ESCAPE[byte as usize]; + if escape == 0 { + continue; + } + + if start < i { + tri!(formatter.write_string_fragment(writer, &value[start..i])); + } + + let char_escape = CharEscape::from_escape_table(escape, byte); + tri!(formatter.write_char_escape(writer, char_escape)); + + start = i + 1; + } + + if start == bytes.len() { + return Ok(()); + } + + formatter.write_string_fragment(writer, &value[start..]) +} + +const BB: u8 = b'b'; // \x08 +const TT: u8 = b't'; // \x09 +const NN: u8 = b'n'; // \x0A +const FF: u8 = b'f'; // \x0C +const RR: u8 = b'r'; // \x0D +const QU: u8 = b'"'; // \x22 +const BS: u8 = b'\\'; // \x5C +const UU: u8 = b'u'; // \x00...\x1F except the ones above +const __: u8 = 0; + +// Lookup table of escape sequences. A value of b'x' at index i means that byte +// i is escaped as "\x" in JSON. A value of 0 means that byte i is not escaped. 
+static ESCAPE: [u8; 256] = [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + UU, UU, UU, UU, UU, UU, UU, UU, BB, TT, NN, UU, FF, RR, UU, UU, // 0 + UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, // 1 + __, __, QU, __, __, __, __, __, __, __, __, __, __, __, __, __, // 2 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 3 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 + __, __, __, __, __, __, __, __, __, __, __, __, BS, __, __, __, // 5 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F +]; + +pub enum Compound<'a, W: 'a, F: 'a> { + Map { + ser: &'a mut PythonSerializer, + state: State, + }, + Number { + ser: &'a mut PythonSerializer, + }, +} + +/// Represents a character escape code in a type-safe manner. 
+pub enum CharEscape {} + +impl CharEscape { + #[inline] + fn from_escape_table(escape: u8, byte: u8) -> serde_json::ser::CharEscape { + match escape { + self::BB => serde_json::ser::CharEscape::Backspace, + self::TT => serde_json::ser::CharEscape::Tab, + self::NN => serde_json::ser::CharEscape::LineFeed, + self::FF => serde_json::ser::CharEscape::FormFeed, + self::RR => serde_json::ser::CharEscape::CarriageReturn, + self::QU => serde_json::ser::CharEscape::Quote, + self::BS => serde_json::ser::CharEscape::ReverseSolidus, + self::UU => serde_json::ser::CharEscape::AsciiControl(byte), + _ => unreachable!(), + } + } +} + +struct MapKeySerializer<'a, W: 'a, F: 'a> { + ser: &'a mut PythonSerializer, +} + +fn key_must_be_a_string() -> PythonSerializerError { + PythonSerializerError { + message: "Key must be a string".to_string(), + } +} +fn invalid_number() -> PythonSerializerError { + PythonSerializerError { + message: "Invalid Number".to_string(), + } +} + +impl<'a, W, F> serde::ser::Serializer for MapKeySerializer<'a, W, F> +where + W: io::Write, + F: Formatter, +{ + type Ok = (); + type Error = PythonSerializerError; + + #[inline] + fn serialize_str(self, value: &str) -> Result<()> { + self.ser.serialize_str(value) + } + + #[inline] + fn serialize_unit_variant(self, _name: &'static str, _variant_index: u32, variant: &'static str) -> Result<()> { + self.ser.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + type SerializeSeq = Impossible<(), PythonSerializerError>; + type SerializeTuple = Impossible<(), PythonSerializerError>; + type SerializeTupleStruct = Impossible<(), PythonSerializerError>; + type SerializeTupleVariant = Impossible<(), PythonSerializerError>; + type SerializeMap = Impossible<(), PythonSerializerError>; + type SerializeStruct = Impossible<(), PythonSerializerError>; + type SerializeStructVariant = 
Impossible<(), PythonSerializerError>; + + fn serialize_bool(self, _value: bool) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_i8(self, value: i8) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_i8(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_i16(self, value: i16) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_i16(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_i32(self, value: i32) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_i32(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_i64(self, value: i64) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_i64(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| 
PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + serde_if_integer128! { + fn serialize_i128(self, value: i128) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_number_str(&mut self.ser.writer, &value.to_string()) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + } + + fn serialize_u8(self, value: u8) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_u8(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_u16(self, value: u16) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_u16(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + fn serialize_u32(self, value: u32) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_u32(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { 
message: e.to_string() })); + Ok(()) + } + + fn serialize_u64(self, value: u64) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_u64(&mut self.ser.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + + serde_if_integer128! { + fn serialize_u128(self, value: u128) -> Result<()> { + tri!(self + .ser + .formatter + .begin_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .write_number_str(&mut self.ser.writer, &value.to_string()) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + tri!(self + .ser + .formatter + .end_string(&mut self.ser.writer) + .map_err(|e| PythonSerializerError { message: e.to_string() })); + Ok(()) + } + } + + fn serialize_f32(self, _value: f32) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_f64(self, _value: f64) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_char(self, value: char) -> Result<()> { + self.ser.serialize_str(&value.to_string()) + } + + fn serialize_bytes(self, _value: &[u8]) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_unit(self) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(key_must_be_a_string()) + } + + fn serialize_none(self) -> Result<()> { + Err(key_must_be_a_string()) + } + + fn serialize_some(self, _value: &T) -> Result<()> + where + T: ?Sized + Serialize, 
+ { + Err(key_must_be_a_string()) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(key_must_be_a_string()) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(key_must_be_a_string()) + } + + fn collect_str(self, value: &T) -> Result<()> + where + T: ?Sized + std::fmt::Display, + { + self.ser.collect_str(value) + } +} + +struct NumberStrEmitter<'a, W: 'a + io::Write, F: 'a + Formatter>(&'a mut PythonSerializer); + +impl<'a, W: io::Write, F: Formatter> serde::ser::Serializer for NumberStrEmitter<'a, W, F> { + type Ok = (); + type Error = PythonSerializerError; + + type SerializeSeq = Impossible<(), PythonSerializerError>; + type SerializeTuple = Impossible<(), PythonSerializerError>; + type SerializeTupleStruct = Impossible<(), PythonSerializerError>; + type SerializeTupleVariant = Impossible<(), PythonSerializerError>; + type SerializeMap = Impossible<(), PythonSerializerError>; + type SerializeStruct = Impossible<(), PythonSerializerError>; + type SerializeStructVariant = Impossible<(), PythonSerializerError>; + + fn serialize_bool(self, _v: bool) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_i8(self, _v: i8) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_i16(self, _v: i16) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_i32(self, _v: i32) -> 
Result<()> { + Err(invalid_number()) + } + + fn serialize_i64(self, _v: i64) -> Result<()> { + Err(invalid_number()) + } + + serde_if_integer128! { + fn serialize_i128(self, _v: i128) -> Result<()> { + Err(invalid_number()) + } + } + + fn serialize_u8(self, _v: u8) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_u16(self, _v: u16) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_u32(self, _v: u32) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_u64(self, _v: u64) -> Result<()> { + Err(invalid_number()) + } + + serde_if_integer128! { + fn serialize_u128(self, _v: u128) -> Result<()> { + Err(invalid_number()) + } + } + + fn serialize_f32(self, _v: f32) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_f64(self, _v: f64) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_char(self, _v: char) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_str(self, value: &str) -> Result<()> { + let NumberStrEmitter(serializer) = self; + serializer + .formatter + .write_number_str(&mut serializer.writer, value) + .map_err(|e| PythonSerializerError { message: e.to_string() }) + } + + fn serialize_bytes(self, _value: &[u8]) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_none(self) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_some(self, _value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(invalid_number()) + } + + fn serialize_unit(self) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_unit_variant(self, _name: &'static str, _variant_index: u32, _variant: &'static str) -> Result<()> { + Err(invalid_number()) + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(invalid_number()) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: 
&'static str, + _value: &T, + ) -> Result<()> + where + T: ?Sized + Serialize, + { + Err(invalid_number()) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(invalid_number()) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(invalid_number()) + } + + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + Err(invalid_number()) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(invalid_number()) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(invalid_number()) + } + + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(invalid_number()) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(invalid_number()) + } +} diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 7c24ff6db..cfccc748a 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -15,6 +15,7 @@ use crate::build_tools::py_schema_err; use crate::build_tools::py_schema_error_type; use crate::definitions::DefinitionsBuilder; use crate::py_gc::PyGcTraverse; +use crate::serializers::ser::PythonSerializer; use crate::tools::{py_err, SchemaDict}; use super::errors::se_err_py_err; @@ -112,7 +113,7 @@ combined_serializer! 
{ Nullable: super::type_serializers::nullable::NullableSerializer; Int: super::type_serializers::simple::IntSerializer; Bool: super::type_serializers::simple::BoolSerializer; - Float: super::type_serializers::simple::FloatSerializer; + Float: super::type_serializers::float::FloatSerializer; Decimal: super::type_serializers::decimal::DecimalSerializer; Str: super::type_serializers::string::StrSerializer; Bytes: super::type_serializers::bytes::BytesSerializer; @@ -352,12 +353,12 @@ pub(crate) fn to_json_bytes( Some(indent) => { let indent = vec![b' '; indent]; let formatter = PrettyFormatter::with_indent(&indent); - let mut ser = serde_json::Serializer::with_formatter(writer, formatter); + let mut ser = PythonSerializer::with_formatter(writer, formatter); serializer.serialize(&mut ser).map_err(se_err_py_err)?; ser.into_inner() } None => { - let mut ser = serde_json::Serializer::new(writer); + let mut ser = PythonSerializer::new(writer); serializer.serialize(&mut ser).map_err(se_err_py_err)?; ser.into_inner() } diff --git a/src/serializers/type_serializers/float.rs b/src/serializers/type_serializers/float.rs new file mode 100644 index 000000000..23dcacf1a --- /dev/null +++ b/src/serializers/type_serializers/float.rs @@ -0,0 +1,102 @@ +use pyo3::types::PyDict; +use pyo3::{intern, prelude::*}; + +use std::borrow::Cow; + +use serde::Serializer; + +use crate::definitions::DefinitionsBuilder; +use crate::serializers::config::InfNanMode; +use crate::tools::SchemaDict; + +use super::simple::to_str_json_key; +use super::{ + infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, + SerMode, TypeSerializer, +}; + +#[derive(Debug, Clone)] +pub struct FloatSerializer { + inf_nan_mode: InfNanMode, +} + +impl BuildSerializer for FloatSerializer { + const EXPECTED_TYPE: &'static str = "float"; + + fn build( + schema: &PyDict, + config: Option<&PyDict>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let inf_nan_mode 
= config + .and_then(|c| c.get_as(intern!(schema.py(), "ser_json_inf_nan")).transpose()) + .transpose()? + .unwrap_or_default(); + Ok(Self { inf_nan_mode }.into()) + } +} + +impl_py_gc_traverse!(FloatSerializer {}); + +impl TypeSerializer for FloatSerializer { + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult { + let py = value.py(); + match extra.ob_type_lookup.is_type(value, ObType::Float) { + IsType::Exact => Ok(value.into_py(py)), + IsType::Subclass => match extra.mode { + SerMode::Json => { + let rust_value = value.extract::()?; + Ok(rust_value.to_object(py)) + } + _ => infer_to_python(value, include, exclude, extra), + }, + IsType::False => { + extra.warnings.on_fallback_py(self.get_name(), value, extra)?; + infer_to_python(value, include, exclude, extra) + } + } + } + + fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { + match extra.ob_type_lookup.is_type(key, ObType::Float) { + IsType::Exact | IsType::Subclass => to_str_json_key(key), + IsType::False => { + extra.warnings.on_fallback_py(self.get_name(), key, extra)?; + infer_json_key(key, extra) + } + } + } + + fn serde_serialize( + &self, + value: &PyAny, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result { + match value.extract::() { + Ok(v) => { + if (v.is_nan() || v.is_infinite()) && self.inf_nan_mode == InfNanMode::Null { + serializer.serialize_none() + } else { + serializer.serialize_f64(v) + } + } + Err(_) => { + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) + } + } + } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } +} diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index b942b5b86..decb07aaf 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -5,6 +5,7 @@ pub mod 
datetime_etc; pub mod decimal; pub mod definitions; pub mod dict; +pub mod float; pub mod format; pub mod function; pub mod generator; diff --git a/src/serializers/type_serializers/simple.rs b/src/serializers/type_serializers/simple.rs index f0d90c2bf..dafb2b786 100644 --- a/src/serializers/type_serializers/simple.rs +++ b/src/serializers/type_serializers/simple.rs @@ -180,4 +180,3 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult> { } build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key); -build_simple_serializer!(FloatSerializer, "float", f64, ObType::Float, to_str_json_key); diff --git a/tests/serializers/test_simple.py b/tests/serializers/test_simple.py index 9bbaad05b..b63208c07 100644 --- a/tests/serializers/test_simple.py +++ b/tests/serializers/test_simple.py @@ -136,3 +136,30 @@ def test_numpy(): assert type(v) == float assert s.to_json(numpy.float64(1.0)) == b'1.0' + + +@pytest.mark.parametrize( + 'value,expected_json,config', + [ + # default values of ser_json_inf_nan + (float('inf'), 'null', {}), + (float('-inf'), 'null', {}), + (float('nan'), 'null', {}), + # explicit values of ser_json_inf_nan + (float('inf'), 'null', {'ser_json_inf_nan': 'null'}), + (float('-inf'), 'null', {'ser_json_inf_nan': 'null'}), + (float('nan'), 'null', {'ser_json_inf_nan': 'null'}), + (float('inf'), 'Infinity', {'ser_json_inf_nan': 'constants'}), + (float('-inf'), '-Infinity', {'ser_json_inf_nan': 'constants'}), + (float('nan'), 'NaN', {'ser_json_inf_nan': 'constants'}), + ], +) +def test_float_inf_and_nan_serializers(value, expected_json, config): + s = SchemaSerializer(core_schema.float_schema(), config) + + # Python can represent these values without needing any changes + assert s.to_python(value) is value + assert s.to_python(value, mode='json') is value + + # Serialized JSON value respects the ser_json_inf_nan setting + assert s.to_json(value).decode() == expected_json diff --git a/tests/test_errors.py b/tests/test_errors.py index 
3683150ec..0d7a966e1 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -789,7 +789,7 @@ def raise_py_error(v: Any) -> Any: with pytest.raises(ValidationError) as exc_info: s.validate_python('anything') - exc = exc_info.value.errors()[0]['ctx']['error'] # type: ignore + exc = exc_info.value.errors()[0]['ctx']['error'] assert isinstance(exc, ValueError) assert isinstance(exc.__context__, AssertionError) diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index 35b04c3f9..4e3bda0c4 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -227,6 +227,7 @@ def test_float_key(py_and_json: PyAndJson): ('NaN', True, FunctionCheck(math.isnan)), ('NaN', False, Err("Input should be a finite number [type=finite_number, input_value='NaN', input_type=str]")), ('+inf', True, FunctionCheck(lambda x: math.isinf(x) and x > 0)), + ('inf', True, FunctionCheck(lambda x: math.isinf(x) and x > 0)), ( '+inf', False, From 0c461461668d2630d86e027a087b861cccd5e0ab Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 7 Nov 2023 13:27:11 +0000 Subject: [PATCH 108/550] PGO build for MacOS M1 (#1063) Co-authored-by: David Hewitt --- .github/workflows/ci.yml | 74 +++++++++++++++++++++++++++------------- Cargo.lock | 2 +- Cargo.toml | 2 +- src/input/input_json.rs | 2 -- 4 files changed, 53 insertions(+), 27 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62e998bda..a451eb725 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,7 +70,7 @@ jobs: - '3.9' - '3.10' - '3.11' - - '3.12-dev' + - '3.12' - 'pypy3.7' - 'pypy3.8' - 'pypy3.9' @@ -389,7 +389,7 @@ jobs: interpreter: 3.11 3.12 - os: macos target: aarch64 - interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 pypy3.8 pypy3.9 pypy3.10 + interpreter: 3.7 3.8 3.9 pypy3.8 pypy3.9 pypy3.10 - os: ubuntu platform: linux target: i686 @@ -465,25 +465,26 @@ jobs: path: dist build-pgo: - name: build pgo-optimized on ${{ matrix.platform || 
matrix.os }} (${{ matrix.interpreter}} - ${{ matrix.target }} - ${{ matrix.manylinux || 'auto' }}) + name: build pgo-optimized on ${{ matrix.os }} / ${{ matrix.interpreter }} # only run on push to main and on release if: startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || contains(github.event.pull_request.labels.*.name, 'Full Build') strategy: fail-fast: false matrix: - os: [ubuntu, windows] - target: [x86_64] - manylinux: [auto] - interpreter: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12-dev", "pypy3.7", "pypy3.8", "pypy3.9", "pypy3.10"] + os: [ubuntu-latest, windows-latest, macos-latest-xlarge] + interpreter: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] include: - - os: ubuntu - platform: linux - - os: windows + - os: windows-latest ls: dir - - interpreter: 3.12-dev - maturin-interpreter: "3.12" - - runs-on: ${{ matrix.os }}-latest + exclude: + - os: macos-latest-xlarge + interpreter: '3.7' + - os: macos-latest-xlarge + interpreter: '3.8' + - os: macos-latest-xlarge + interpreter: '3.9' + + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 @@ -491,7 +492,6 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.interpreter }} - architecture: ${{ matrix.python-architecture || 'x64' }} - name: install rust stable id: rust-toolchain @@ -504,15 +504,15 @@ jobs: # generate self-schema now, so we don't have to do so inside docker in maturin build - run: python generate_self_schema.py + - run: rustc --version --verbose + - name: build initial wheel uses: PyO3/maturin-action@v1 with: - target: ${{ matrix.target }} - manylinux: ${{ matrix.manylinux || 'auto' }} args: > --release --out pgo-wheel - --interpreter ${{ matrix.maturin-interpreter || matrix.interpreter }} + --interpreter ${{ matrix.interpreter }} rust-toolchain: stable docker-options: -e CI env: @@ -536,12 +536,10 @@ jobs: - name: build pgo-optimized wheel uses: PyO3/maturin-action@v1 with: - target: ${{ matrix.target }} - manylinux: ${{ matrix.manylinux || 
'auto' }} args: > --release --out dist - --interpreter ${{ matrix.maturin-interpreter || matrix.interpreter }} + --interpreter ${{ matrix.interpreter }} rust-toolchain: stable docker-options: -e CI env: @@ -551,7 +549,7 @@ jobs: - uses: actions/upload-artifact@v3 with: - name: pypi_files + name: pypi_files_pgo path: dist inspect-pypi-assets: @@ -567,7 +565,19 @@ jobs: name: pypi_files path: dist - - name: list dist files + - name: list dist files before PGO builds + run: | + ls -lh dist/ + ls -l dist/ + echo "`ls dist | wc -l` files" + + - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) + uses: actions/download-artifact@v3 + with: + name: pypi_files_pgo + path: dist + + - name: list dist files with PGO builds run: | ls -lh dist/ ls -l dist/ @@ -607,6 +617,12 @@ jobs: name: pypi_files path: dist + - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) + uses: actions/download-artifact@v3 + with: + name: pypi_files_pgo + path: dist + - uses: uraimo/run-on-arch-action@v2.5.1 name: install & test with: @@ -659,6 +675,12 @@ jobs: name: pypi_files path: dist + - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) + uses: actions/download-artifact@v3 + with: + name: pypi_files_pgo + path: dist + - run: pip install typing-extensions - run: pip install -r tests/requirements.txt - run: pip install pydantic-core --no-index --no-deps --find-links dist --force-reinstall @@ -688,6 +710,12 @@ jobs: name: pypi_files path: dist + - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) + uses: actions/download-artifact@v3 + with: + name: pypi_files_pgo + path: dist + - run: twine check --strict dist/* - name: upload to pypi diff --git a/Cargo.lock b/Cargo.lock index c6deb4246..13874f1fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,7 +321,7 @@ 
dependencies = [ [[package]] name = "pydantic-core" -version = "2.12.0" +version = "2.13.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 242479712..6c1d1d948 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.12.0" +version = "2.13.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" diff --git a/src/input/input_json.rs b/src/input/input_json.rs index ac552621d..59c120cd0 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -368,8 +368,6 @@ impl AsLocItem for String { } } -/// TODO: it would be good to get JsonInput and StringMapping string variants to go through this -/// implementation /// Required for JSON Object keys so the string can behave like an Input impl<'a> Input<'a> for String { fn as_error_value(&'a self) -> InputValue<'a> { From d08b4f38c792849fa9da4520ae291944ccdec8e0 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 8 Nov 2023 21:38:46 +0000 Subject: [PATCH 109/550] run pydantic integration tests with lax xfail (#1054) --- .github/workflows/ci.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a451eb725..fa9dfa0e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -214,7 +214,9 @@ jobs: - run: pdm info && pdm list working-directory: pydantic - - run: pdm run pytest + # Run pytest with lax xfail because we often add tests to pydantic + # which xfail on a pending release of pydantic-core + - run: pdm run pytest --override-ini=xfail_strict=False working-directory: pydantic lint: From f409e0013e4f7b1937c2bd14ca877f7d27541be3 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 8 Nov 2023 21:40:32 +0000 Subject: [PATCH 110/550] replace ultra_strict with new union implementation (#867) Co-authored-by: Samuel Colvin --- src/definitions.rs | 12 - src/input/input_abstract.rs | 162 +----- src/input/input_json.rs | 
239 ++++----- src/input/input_python.rs | 495 +++++++++--------- src/input/input_string.rs | 68 ++- src/input/mod.rs | 2 +- src/input/return_enums.rs | 35 +- src/validators/any.rs | 16 +- src/validators/arguments.rs | 23 +- src/validators/bool.rs | 13 +- src/validators/bytes.rs | 24 +- src/validators/call.rs | 17 - src/validators/callable.rs | 12 +- src/validators/chain.rs | 8 - src/validators/custom_error.rs | 8 - src/validators/dataclass.rs | 28 +- src/validators/date.rs | 18 +- src/validators/datetime.rs | 12 +- src/validators/decimal.rs | 8 - src/validators/definitions.rs | 29 - src/validators/dict.rs | 14 +- src/validators/float.rs | 22 +- src/validators/frozenset.rs | 21 +- src/validators/function.rs | 43 +- src/validators/generator.rs | 38 +- src/validators/int.rs | 26 +- src/validators/is_instance.rs | 11 +- src/validators/is_subclass.rs | 11 +- src/validators/json.rs | 18 +- src/validators/json_or_python.rs | 13 +- src/validators/lax_or_strict.rs | 23 +- src/validators/list.rs | 25 +- src/validators/literal.rs | 16 +- src/validators/mod.rs | 27 +- src/validators/model.rs | 15 +- src/validators/model_fields.rs | 23 +- src/validators/none.rs | 8 - src/validators/nullable.rs | 8 - src/validators/set.rs | 21 +- src/validators/string.rs | 25 +- src/validators/time.rs | 12 +- src/validators/timedelta.rs | 12 +- src/validators/tuple.rs | 54 +- src/validators/typed_dict.rs | 18 +- src/validators/union.rs | 200 +++---- src/validators/url.rs | 35 +- src/validators/uuid.rs | 18 +- src/validators/validation_state.rs | 39 +- src/validators/with_default.rs | 8 - .../validators/test_definitions_recursive.py | 19 + tests/validators/test_union.py | 298 ++++++++++- 51 files changed, 1011 insertions(+), 1339 deletions(-) diff --git a/src/definitions.rs b/src/definitions.rs index 1eb813015..4627fd2d1 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -31,21 +31,9 @@ use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse}; #[derive(Clone)] pub struct 
Definitions(AHashMap, Definition>); -impl Definitions { - pub fn values(&self) -> impl Iterator> { - self.0.values() - } -} - /// Internal type which contains a definition to be filled pub struct Definition(Arc>); -impl Definition { - pub fn get(&self) -> Option<&T> { - self.0.value.get() - } -} - struct DefinitionInner { value: OnceLock, name: LazyName, diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 52551ef42..ba6fbd0a1 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -6,13 +6,13 @@ use pyo3::{intern, prelude::*}; use jiter::JsonValue; -use crate::errors::{AsLocItem, InputValue, ValResult}; +use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult}; use crate::tools::py_err; use crate::{PyMultiHostUrl, PyUrl}; use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; use super::return_enums::{EitherBytes, EitherInt, EitherString}; -use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping}; +use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, ValidationMatch}; #[derive(Debug, Clone, Copy)] pub enum InputType { @@ -48,7 +48,7 @@ impl TryFrom<&str> for InputType { /// the convention is to either implement: /// * `strict_*` & `lax_*` if they have different behavior /// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same -pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem { +pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { fn as_error_value(&'a self) -> InputValue<'a>; fn identity(&self) -> Option { @@ -91,85 +91,37 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem { fn parse_json(&'a self) -> ValResult<'a, JsonValue>; - fn validate_str(&'a self, strict: bool, coerce_numbers_to_str: bool) -> ValResult> { - if strict { - self.strict_str() - } else { - self.lax_str(coerce_numbers_to_str) - } - } - fn 
strict_str(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_str(&'a self, _coerce_numbers_to_str: bool) -> ValResult> { - self.strict_str() - } + fn validate_str( + &'a self, + strict: bool, + coerce_numbers_to_str: bool, + ) -> ValResult>>; - fn validate_bytes(&'a self, strict: bool) -> ValResult> { - if strict { - self.strict_bytes() - } else { - self.lax_bytes() - } - } - fn strict_bytes(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_bytes(&'a self) -> ValResult> { - self.strict_bytes() - } + fn validate_bytes(&'a self, strict: bool) -> ValResult>>; - fn validate_bool(&self, strict: bool) -> ValResult { - if strict { - self.strict_bool() - } else { - self.lax_bool() - } - } - fn strict_bool(&self) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_bool(&self) -> ValResult { - self.strict_bool() - } + fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch>; - fn validate_int(&'a self, strict: bool) -> ValResult> { - if strict { - self.strict_int() - } else { - self.lax_int() - } - } - fn strict_int(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_int(&'a self) -> ValResult> { - self.strict_int() - } + fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>>; - /// Extract an EitherInt from the input, only allowing exact - /// matches for an Int (no subclasses) fn exact_int(&'a self) -> ValResult> { - self.strict_int() + self.validate_int(true).and_then(|val_match| { + val_match + .require_exact() + .ok_or_else(|| ValError::new(ErrorTypeDefaults::IntType, self)) + }) } /// Extract a String from the input, only allowing exact /// matches for a String (no subclasses) fn exact_str(&'a self) -> ValResult> { - self.strict_str() + self.validate_str(true, false).and_then(|val_match| { + val_match + .require_exact() + .ok_or_else(|| ValError::new(ErrorTypeDefaults::StringType, self)) + }) } - fn 
validate_float(&'a self, strict: bool, ultra_strict: bool) -> ValResult> { - if ultra_strict { - self.ultra_strict_float() - } else if strict { - self.strict_float() - } else { - self.lax_float() - } - } - fn ultra_strict_float(&'a self) -> ValResult>; - fn strict_float(&'a self) -> ValResult>; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_float(&'a self) -> ValResult> { - self.strict_float() - } + fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>>; fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> { if strict { @@ -257,87 +209,25 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem { fn validate_iter(&self) -> ValResult; - fn validate_date(&self, strict: bool) -> ValResult { - if strict { - self.strict_date() - } else { - self.lax_date() - } - } - fn strict_date(&self) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_date(&self) -> ValResult { - self.strict_date() - } + fn validate_date(&self, strict: bool) -> ValResult>; fn validate_time( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if strict { - self.strict_time(microseconds_overflow_behavior) - } else { - self.lax_time(microseconds_overflow_behavior) - } - } - fn strict_time( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_time( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - self.strict_time(microseconds_overflow_behavior) - } + ) -> ValResult>; fn validate_datetime( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if strict { - self.strict_datetime(microseconds_overflow_behavior) - } else { - self.lax_datetime(microseconds_overflow_behavior) - } - } - fn 
strict_datetime( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_datetime( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - self.strict_datetime(microseconds_overflow_behavior) - } + ) -> ValResult>; fn validate_timedelta( &self, strict: bool, microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if strict { - self.strict_timedelta(microseconds_overflow_behavior) - } else { - self.lax_timedelta(microseconds_overflow_behavior) - } - } - fn strict_timedelta( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult; - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn lax_timedelta( - &self, - microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - self.strict_timedelta(microseconds_overflow_behavior) - } + ) -> ValResult>; } /// The problem to solve here is that iterating a `StringMapping` returns an owned diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 59c120cd0..0411a25d6 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -13,6 +13,7 @@ use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; +use super::return_enums::ValidationMatch; use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, @@ -84,97 +85,72 @@ impl<'a> Input<'a> for JsonValue { } } - fn strict_str(&'a self) -> ValResult> { + fn exact_str(&'a self) -> ValResult> { match self { 
JsonValue::Str(s) => Ok(s.as_str().into()), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } - fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { + + fn validate_str( + &'a self, + strict: bool, + coerce_numbers_to_str: bool, + ) -> ValResult>> { + // Justification for `strict` instead of `exact` is that in JSON strings can also + // represent other datatypes such as UUID and date more exactly, so string is a + // converting input + // TODO: in V3 we may want to make JSON str always win if in union, for consistency, + // see https://github.com/pydantic/pydantic-core/pull/867#discussion_r1386582501 match self { - JsonValue::Str(s) => Ok(s.as_str().into()), - JsonValue::Int(i) if coerce_numbers_to_str => Ok(i.to_string().into()), - JsonValue::BigInt(b) if coerce_numbers_to_str => Ok(b.to_string().into()), - JsonValue::Float(f) if coerce_numbers_to_str => Ok(f.to_string().into()), + JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_str().into())), + JsonValue::Int(i) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(i.to_string().into())), + JsonValue::BigInt(b) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(b.to_string().into())), + JsonValue::Float(f) if !strict && coerce_numbers_to_str => Ok(ValidationMatch::lax(f.to_string().into())), _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } - fn validate_bytes(&'a self, _strict: bool) -> ValResult> { + fn validate_bytes(&'a self, _strict: bool) -> ValResult>> { match self { - JsonValue::Str(s) => Ok(s.as_bytes().into()), + JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_bytes().into())), _ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), } } - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn strict_bytes(&'a self) -> ValResult> { - self.validate_bytes(false) - } - fn strict_bool(&self) -> ValResult { + fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch> { match self { - JsonValue::Bool(b) => Ok(*b), - 
_ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), - } - } - fn lax_bool(&self) -> ValResult { - match self { - JsonValue::Bool(b) => Ok(*b), - JsonValue::Str(s) => str_as_bool(self, s), - JsonValue::Int(int) => int_as_bool(self, *int), - JsonValue::Float(float) => match float_as_int(self, *float) { + JsonValue::Bool(b) => Ok(ValidationMatch::exact(*b)), + JsonValue::Str(s) if !strict => str_as_bool(self, s).map(ValidationMatch::lax), + JsonValue::Int(int) if !strict => int_as_bool(self, *int).map(ValidationMatch::lax), + JsonValue::Float(float) if !strict => match float_as_int(self, *float) { Ok(int) => int .as_bool() - .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)), + .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)) + .map(ValidationMatch::lax), _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), }, _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), } } - fn strict_int(&'a self) -> ValResult> { - match self { - JsonValue::Int(i) => Ok(EitherInt::I64(*i)), - JsonValue::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), - _ => Err(ValError::new(ErrorTypeDefaults::IntType, self)), - } - } - fn lax_int(&'a self) -> ValResult> { + fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { match self { - JsonValue::Bool(b) => match *b { - true => Ok(EitherInt::I64(1)), - false => Ok(EitherInt::I64(0)), - }, - JsonValue::Int(i) => Ok(EitherInt::I64(*i)), - JsonValue::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), - JsonValue::Float(f) => float_as_int(self, *f), - JsonValue::Str(str) => str_as_int(self, str), + JsonValue::Int(i) => Ok(ValidationMatch::exact(EitherInt::I64(*i))), + JsonValue::BigInt(b) => Ok(ValidationMatch::exact(EitherInt::BigInt(b.clone()))), + JsonValue::Bool(b) if !strict => Ok(ValidationMatch::lax(EitherInt::I64((*b).into()))), + JsonValue::Float(f) if !strict => float_as_int(self, *f).map(ValidationMatch::lax), + JsonValue::Str(str) if !strict => str_as_int(self, 
str).map(ValidationMatch::lax), _ => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } - fn ultra_strict_float(&'a self) -> ValResult> { + fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { match self { - JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), - _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), - } - } - fn strict_float(&'a self) -> ValResult> { - match self { - JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), - JsonValue::Int(i) => Ok(EitherFloat::F64(*i as f64)), - _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), - } - } - fn lax_float(&'a self) -> ValResult> { - match self { - JsonValue::Bool(b) => match *b { - true => Ok(EitherFloat::F64(1.0)), - false => Ok(EitherFloat::F64(0.0)), - }, - JsonValue::Float(f) => Ok(EitherFloat::F64(*f)), - JsonValue::Int(i) => Ok(EitherFloat::F64(*i as f64)), - JsonValue::Str(str) => str_as_float(self, str), + JsonValue::Float(f) => Ok(ValidationMatch::exact(EitherFloat::F64(*f))), + JsonValue::Int(i) => Ok(ValidationMatch::strict(EitherFloat::F64(*i as f64))), + JsonValue::Bool(b) if !strict => Ok(ValidationMatch::lax(EitherFloat::F64(if *b { 1.0 } else { 0.0 }))), + JsonValue::Str(str) if !strict => str_as_float(self, str).map(ValidationMatch::lax), _ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } @@ -270,34 +246,24 @@ impl<'a> Input<'a> for JsonValue { } } - fn validate_date(&self, _strict: bool) -> ValResult { + fn validate_date(&self, _strict: bool) -> ValResult> { match self { - JsonValue::Str(v) => bytes_as_date(self, v.as_bytes()), + JsonValue::Str(v) => bytes_as_date(self, v.as_bytes()).map(ValidationMatch::strict), _ => Err(ValError::new(ErrorTypeDefaults::DateType, self)), } } - // NO custom `lax_date` implementation, if strict_date fails, the validator will fallback to lax_datetime - // then check there's no remainder - #[cfg_attr(has_coverage_attribute, coverage(off))] - fn strict_date(&self) -> ValResult { - 
self.validate_date(false) - } - - fn strict_time( + fn validate_time( &self, + strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { match self { - JsonValue::Str(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), - _ => Err(ValError::new(ErrorTypeDefaults::TimeType, self)), - } - } - fn lax_time(&self, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior) -> ValResult { - match self { - JsonValue::Str(v) => bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior), - JsonValue::Int(v) => int_as_time(self, *v, 0), - JsonValue::Float(v) => float_as_time(self, *v), - JsonValue::BigInt(_) => Err(ValError::new( + JsonValue::Str(v) => { + bytes_as_time(self, v.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::strict) + } + JsonValue::Int(v) if !strict => int_as_time(self, *v, 0).map(ValidationMatch::lax), + JsonValue::Float(v) if !strict => float_as_time(self, *v).map(ValidationMatch::lax), + JsonValue::BigInt(_) if !strict => Err(ValError::new( ErrorType::TimeParsing { error: Cow::Borrowed( speedate::ParseError::TimeTooLarge @@ -312,44 +278,36 @@ impl<'a> Input<'a> for JsonValue { } } - fn strict_datetime( - &self, - microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - match self { - JsonValue::Str(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), - _ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), - } - } - fn lax_datetime( + fn validate_datetime( &self, - microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + strict: bool, + microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult> { match self { - JsonValue::Str(v) => bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior), - JsonValue::Int(v) => int_as_datetime(self, *v, 0), - JsonValue::Float(v) => 
float_as_datetime(self, *v), + JsonValue::Str(v) => { + bytes_as_datetime(self, v.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::strict) + } + JsonValue::Int(v) if !strict => int_as_datetime(self, *v, 0).map(ValidationMatch::lax), + JsonValue::Float(v) if !strict => float_as_datetime(self, *v).map(ValidationMatch::lax), _ => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), } } - fn strict_timedelta( - &self, - microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - match self { - JsonValue::Str(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), - _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), - } - } - fn lax_timedelta( + fn validate_timedelta( &self, - microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + strict: bool, + microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult> { match self { - JsonValue::Str(v) => bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior), - JsonValue::Int(v) => Ok(int_as_duration(self, *v)?.into()), - JsonValue::Float(v) => Ok(float_as_duration(self, *v)?.into()), + JsonValue::Str(v) => { + bytes_as_timedelta(self, v.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::strict) + } + JsonValue::Int(v) if !strict => { + int_as_duration(self, *v).map(|duration| ValidationMatch::lax(duration.into())) + } + JsonValue::Float(v) if !strict => { + float_as_duration(self, *v).map(|duration| ValidationMatch::lax(duration.into())) + } _ => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } @@ -399,30 +357,36 @@ impl<'a> Input<'a> for String { JsonValue::parse(self.as_bytes(), true).map_err(|e| map_json_err(self, e)) } - fn strict_str(&'a self) -> ValResult> { - Ok(self.as_str().into()) + fn validate_str( + &'a self, + _strict: bool, + _coerce_numbers_to_str: bool, + ) -> ValResult>> { + // Justification for `strict` 
instead of `exact` is that in JSON strings can also + // represent other datatypes such as UUID and date more exactly, so string is a + // converting input + // TODO: in V3 we may want to make JSON str always win if in union, for consistency, + // see https://github.com/pydantic/pydantic-core/pull/867#discussion_r1386582501 + Ok(ValidationMatch::strict(self.as_str().into())) } - fn strict_bytes(&'a self) -> ValResult> { - Ok(self.as_bytes().into()) + fn validate_bytes(&'a self, _strict: bool) -> ValResult>> { + Ok(ValidationMatch::strict(self.as_bytes().into())) } - fn strict_bool(&self) -> ValResult { - str_as_bool(self, self) + fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch> { + str_as_bool(self, self).map(ValidationMatch::lax) } - fn strict_int(&'a self) -> ValResult> { + fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { match self.parse() { - Ok(i) => Ok(EitherInt::I64(i)), + Ok(i) => Ok(ValidationMatch::lax(EitherInt::I64(i))), Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), } } - fn ultra_strict_float(&'a self) -> ValResult> { - self.strict_float() - } - fn strict_float(&'a self) -> ValResult> { - str_as_float(self, self) + fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { + str_as_float(self, self).map(ValidationMatch::lax) } fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { @@ -462,29 +426,32 @@ impl<'a> Input<'a> for String { Ok(string_to_vec(self).into()) } - fn strict_date(&self) -> ValResult { - bytes_as_date(self, self.as_bytes()) + fn validate_date(&self, _strict: bool) -> ValResult> { + bytes_as_date(self, self.as_bytes()).map(ValidationMatch::lax) } - fn strict_time( + fn validate_time( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - bytes_as_time(self, self.as_bytes(), microseconds_overflow_behavior) + ) -> ValResult> { + bytes_as_time(self, self.as_bytes(), 
microseconds_overflow_behavior).map(ValidationMatch::lax) } - fn strict_datetime( + fn validate_datetime( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior) + ) -> ValResult> { + bytes_as_datetime(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } - fn strict_timedelta( + fn validate_timedelta( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior) + ) -> ValResult> { + bytes_as_timedelta(self, self.as_bytes(), microseconds_overflow_behavior).map(ValidationMatch::lax) } } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index de59ebce0..90d2c2a8b 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -16,6 +16,7 @@ use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; use crate::validators::decimal::{create_decimal, get_decimal_type}; +use crate::validators::Exactness; use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl}; use super::datetime::{ @@ -23,6 +24,7 @@ use super::datetime::{ float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; +use super::return_enums::ValidationMatch; use super::shared::{ decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int, @@ -203,194 +205,185 @@ impl<'a> Input<'a> for PyAny { JsonValue::parse(bytes, true).map_err(|e| map_json_err(self, e)) } - fn strict_str(&'a self) -> ValResult> { - if let Ok(py_str) = PyString::try_from_exact(self) { - Ok(py_str.into()) + fn validate_str( + &'a self, + strict: bool, + coerce_numbers_to_str: 
bool, + ) -> ValResult>> { + if let Ok(py_str) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(py_str.into())); } else if let Ok(py_str) = self.downcast::() { // force to a rust string to make sure behavior is consistent whether or not we go via a // rust string in StrConstrainedValidator - e.g. to_lower - Ok(py_string_str(py_str)?.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::StringType, self)) + return Ok(ValidationMatch::strict(py_string_str(py_str)?.into())); + } + + 'lax: { + if !strict { + return if let Ok(bytes) = self.downcast::() { + match from_utf8(bytes.as_bytes()) { + Ok(str) => Ok(str.into()), + Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), + } + } else if let Ok(py_byte_array) = self.downcast::() { + // Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated, + // and we immediately copy the bytes into a new Python string + match from_utf8(unsafe { py_byte_array.as_bytes() }) { + // Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the + // final output needs to be Python anyway. 
+ Ok(s) => Ok(PyString::new(self.py(), s).into()), + Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), + } + } else if coerce_numbers_to_str && !PyBool::is_exact_type_of(self) && { + let py = self.py(); + let decimal_type: Py = get_decimal_type(py); + + // only allow int, float, and decimal (not bool) + self.is_instance_of::() + || self.is_instance_of::() + || self.is_instance(decimal_type.as_ref(py)).unwrap_or_default() + } { + Ok(self.str()?.into()) + } else if let Some(enum_val) = maybe_as_enum(self) { + Ok(enum_val.str()?.into()) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } + + Err(ValError::new(ErrorTypeDefaults::StringType, self)) } - fn lax_str(&'a self, coerce_numbers_to_str: bool) -> ValResult> { + fn exact_str(&'a self) -> ValResult> { if let Ok(py_str) = ::try_from_exact(self) { - Ok(py_str.into()) - } else if let Ok(py_str) = self.downcast::() { - // force to a rust string to make sure behaviour is consistent whether or not we go via a - // rust string in StrConstrainedValidator - e.g. to_lower - Ok(py_string_str(py_str)?.into()) - } else if let Ok(bytes) = self.downcast::() { - let str = match from_utf8(bytes.as_bytes()) { - Ok(s) => s, - Err(_) => return Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), - }; - Ok(str.into()) - } else if let Ok(py_byte_array) = self.downcast::() { - // Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated, - // and we immediately copy the bytes into a new Python string - let s = match from_utf8(unsafe { py_byte_array.as_bytes() }) { - // Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the - // final output needs to be Python anyway. 
- Ok(s) => PyString::new(self.py(), s), - Err(_) => return Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)), - }; - Ok(s.into()) - } else if coerce_numbers_to_str && !PyBool::is_exact_type_of(self) && { - let py = self.py(); - let decimal_type: Py = get_decimal_type(py); - - // only allow int, float, and decimal (not bool) - self.is_instance_of::() - || self.is_instance_of::() - || self.is_instance(decimal_type.as_ref(py)).unwrap_or_default() - } { - Ok(self.str()?.into()) - } else if let Some(enum_val) = maybe_as_enum(self) { - Ok(enum_val.str()?.into()) + Ok(EitherString::Py(py_str)) } else { - Err(ValError::new(ErrorTypeDefaults::StringType, self)) + Err(ValError::new(ErrorTypeDefaults::IntType, self)) } } - fn strict_bytes(&'a self) -> ValResult> { - if let Ok(py_bytes) = self.downcast::() { - Ok(py_bytes.into()) + fn exact_int(&'a self) -> ValResult> { + if PyInt::is_exact_type_of(self) { + Ok(EitherInt::Py(self)) } else { - Err(ValError::new(ErrorTypeDefaults::BytesType, self)) + Err(ValError::new(ErrorTypeDefaults::IntType, self)) } } - fn lax_bytes(&'a self) -> ValResult> { - if let Ok(py_bytes) = self.downcast::() { - Ok(py_bytes.into()) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - Ok(str.as_bytes().into()) - } else if let Ok(py_byte_array) = self.downcast::() { - Ok(py_byte_array.to_vec().into()) - } else { - Err(ValError::new(ErrorTypeDefaults::BytesType, self)) + fn validate_bytes(&'a self, strict: bool) -> ValResult>> { + if let Ok(py_bytes) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(py_bytes.into())); + } else if let Ok(py_bytes) = self.downcast::() { + return Ok(ValidationMatch::strict(py_bytes.into())); } - } - fn strict_bool(&self) -> ValResult { - if let Ok(bool) = self.downcast::() { - Ok(bool.is_true()) - } else { - Err(ValError::new(ErrorTypeDefaults::BoolType, self)) + 'lax: { + if !strict { + return if let Ok(py_str) = self.downcast::() { + let str = 
py_string_str(py_str)?; + Ok(str.as_bytes().into()) + } else if let Ok(py_byte_array) = self.downcast::() { + Ok(py_byte_array.to_vec().into()) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } + + Err(ValError::new(ErrorTypeDefaults::BytesType, self)) } - fn lax_bool(&self) -> ValResult { + fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch> { if let Ok(bool) = self.downcast::() { - Ok(bool.is_true()) - } else if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? { - str_as_bool(self, &cow_str) - } else if let Ok(int) = extract_i64(self) { - int_as_bool(self, int) - } else if let Ok(float) = self.extract::() { - match float_as_int(self, float) { - Ok(int) => int - .as_bool() - .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)), - _ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), + return Ok(ValidationMatch::exact(bool.is_true())); + } + + if !strict { + if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? 
{ + return str_as_bool(self, &cow_str).map(ValidationMatch::lax); + } else if let Ok(int) = extract_i64(self) { + return int_as_bool(self, int).map(ValidationMatch::lax); + } else if let Ok(float) = self.extract::() { + if let Ok(int) = float_as_int(self, float) { + return int + .as_bool() + .ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)) + .map(ValidationMatch::lax); + }; } - } else { - Err(ValError::new(ErrorTypeDefaults::BoolType, self)) } + + Err(ValError::new(ErrorTypeDefaults::BoolType, self)) } - fn strict_int(&'a self) -> ValResult> { - if PyInt::is_exact_type_of(self) { - Ok(EitherInt::Py(self)) - } else if PyInt::is_type_of(self) { + fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { + if self.is_exact_instance_of::() { + return Ok(ValidationMatch::exact(EitherInt::Py(self))); + } else if self.is_instance_of::() { // bools are a subclass of int, so check for bool type in this specific case - if PyBool::is_exact_type_of(self) { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) + let exactness = if self.is_instance_of::() { + if strict { + return Err(ValError::new(ErrorTypeDefaults::IntType, self)); + } + Exactness::Lax } else { - // force to an int to upcast to a pure python int - EitherInt::upcast(self) + Exactness::Strict + }; + + // force to an int to upcast to a pure python int + return EitherInt::upcast(self).map(|either_int| ValidationMatch::new(either_int, exactness)); + } + + 'lax: { + if !strict { + return if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? { + str_as_int(self, &cow_str) + } else if self.is_exact_instance_of::() { + float_as_int(self, self.extract::()?) 
+ } else if let Ok(decimal) = self.strict_decimal(self.py()) { + decimal_as_int(self.py(), self, decimal) + } else if let Ok(float) = self.extract::() { + float_as_int(self, float) + } else if let Some(enum_val) = maybe_as_enum(self) { + Ok(EitherInt::Py(enum_val)) + } else { + break 'lax; + } + .map(ValidationMatch::lax); } - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) } - } - fn lax_int(&'a self) -> ValResult> { - if PyInt::is_exact_type_of(self) { - Ok(EitherInt::Py(self)) - } else if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? { - // Try strings before subclasses of int as that will be far more common - str_as_int(self, &cow_str) - } else if PyInt::is_type_of(self) { - // force to an int to upcast to a pure python int to maintain current behaviour - EitherInt::upcast(self) - } else if PyFloat::is_exact_type_of(self) { - float_as_int(self, self.extract::()?) - } else if let Ok(decimal) = self.strict_decimal(self.py()) { - decimal_as_int(self.py(), self, decimal) - } else if let Ok(float) = self.extract::() { - float_as_int(self, float) - } else if let Some(enum_val) = maybe_as_enum(self) { - Ok(EitherInt::Py(enum_val)) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } + Err(ValError::new(ErrorTypeDefaults::IntType, self)) } - fn exact_int(&'a self) -> ValResult> { - if PyInt::is_exact_type_of(self) { - Ok(EitherInt::Py(self)) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) + fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { + if let Ok(float) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(EitherFloat::Py(float))); } - } - fn exact_str(&'a self) -> ValResult> { - if let Ok(py_str) = PyString::try_from_exact(self) { - Ok(EitherString::Py(py_str)) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) + if !strict { + if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? 
{ + // checking for bytes and string is fast, so do this before isinstance(float) + return str_as_float(self, &cow_str).map(ValidationMatch::lax); + } } - } - fn ultra_strict_float(&'a self) -> ValResult> { - if self.is_instance_of::() { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) - } else if let Ok(float) = self.downcast::() { - Ok(EitherFloat::Py(float)) - } else { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) - } - } - fn strict_float(&'a self) -> ValResult> { - if let Ok(py_float) = self.downcast_exact::() { - Ok(EitherFloat::Py(py_float)) - } else if let Ok(float) = self.extract::() { - // bools are cast to floats as either 0.0 or 1.0, so check for bool type in this specific case - if (float == 0.0 || float == 1.0) && PyBool::is_exact_type_of(self) { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) + if let Ok(float) = self.extract::() { + let exactness = if self.is_instance_of::() { + if strict { + return Err(ValError::new(ErrorTypeDefaults::FloatType, self)); + } + Exactness::Lax } else { - Ok(EitherFloat::F64(float)) - } - } else { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) + Exactness::Strict + }; + return Ok(ValidationMatch::new(EitherFloat::F64(float), exactness)); } - } - fn lax_float(&'a self) -> ValResult> { - if let Ok(py_float) = self.downcast_exact() { - Ok(EitherFloat::Py(py_float)) - } else if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? 
{ - str_as_float(self, &cow_str) - } else if let Ok(float) = self.extract::() { - Ok(EitherFloat::F64(float)) - } else { - Err(ValError::new(ErrorTypeDefaults::FloatType, self)) - } + Err(ValError::new(ErrorTypeDefaults::FloatType, self)) } fn strict_decimal(&'a self, py: Python<'a>) -> ValResult<&'a PyAny> { @@ -607,128 +600,136 @@ impl<'a> Input<'a> for PyAny { } } - fn strict_date(&self) -> ValResult { - if PyDateTime::is_type_of(self) { - // have to check if it's a datetime first, otherwise the line below converts to a date - Err(ValError::new(ErrorTypeDefaults::DateType, self)) - } else if let Ok(date) = self.downcast::() { - Ok(date.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::DateType, self)) - } - } - - fn lax_date(&self) -> ValResult { - if PyDateTime::is_type_of(self) { + fn validate_date(&self, strict: bool) -> ValResult> { + if let Ok(date) = self.downcast_exact::() { + Ok(ValidationMatch::exact(date.into())) + } else if PyDateTime::is_type_of(self) { // have to check if it's a datetime first, otherwise the line below converts to a date // even if we later try coercion from a datetime, we don't want to return a datetime now Err(ValError::new(ErrorTypeDefaults::DateType, self)) } else if let Ok(date) = self.downcast::() { - Ok(date.into()) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - bytes_as_date(self, str.as_bytes()) - } else if let Ok(py_bytes) = self.downcast::() { - bytes_as_date(self, py_bytes.as_bytes()) + Ok(ValidationMatch::strict(date.into())) + } else if let Some(bytes) = { + if strict { + None + } else if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + Some(str.as_bytes()) + } else if let Ok(py_bytes) = self.downcast::() { + Some(py_bytes.as_bytes()) + } else { + None + } + } { + bytes_as_date(self, bytes).map(ValidationMatch::lax) } else { Err(ValError::new(ErrorTypeDefaults::DateType, self)) } } - fn strict_time( + fn validate_time( &self, - 
_microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if let Ok(time) = self.downcast::() { - Ok(time.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::TimeType, self)) - } - } - - fn lax_time(&self, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior) -> ValResult { - if let Ok(time) = self.downcast::() { - Ok(time.into()) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - bytes_as_time(self, str.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(py_bytes) = self.downcast::() { - bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior) - } else if PyBool::is_exact_type_of(self) { - Err(ValError::new(ErrorTypeDefaults::TimeType, self)) - } else if let Ok(int) = extract_i64(self) { - int_as_time(self, int, 0) - } else if let Ok(float) = self.extract::() { - float_as_time(self, float) - } else { - Err(ValError::new(ErrorTypeDefaults::TimeType, self)) + strict: bool, + microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, + ) -> ValResult> { + if let Ok(time) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(time.into())); + } else if let Ok(time) = self.downcast::() { + return Ok(ValidationMatch::strict(time.into())); + } + + 'lax: { + if !strict { + return if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + bytes_as_time(self, str.as_bytes(), microseconds_overflow_behavior) + } else if let Ok(py_bytes) = self.downcast::() { + bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior) + } else if PyBool::is_exact_type_of(self) { + Err(ValError::new(ErrorTypeDefaults::TimeType, self)) + } else if let Ok(int) = extract_i64(self) { + int_as_time(self, int, 0) + } else if let Ok(float) = self.extract::() { + float_as_time(self, float) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } - } - fn strict_datetime( - &self, - _microseconds_overflow_behavior: 
MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if let Ok(dt) = self.downcast::() { - Ok(dt.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) - } + Err(ValError::new(ErrorTypeDefaults::TimeType, self)) } - fn lax_datetime( + fn validate_datetime( &self, + strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if let Ok(dt) = self.downcast::() { - Ok(dt.into()) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - bytes_as_datetime(self, str.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(py_bytes) = self.downcast::() { - bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior) - } else if PyBool::is_exact_type_of(self) { - Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) - } else if let Ok(int) = extract_i64(self) { - int_as_datetime(self, int, 0) - } else if let Ok(float) = self.extract::() { - float_as_datetime(self, float) - } else if let Ok(date) = self.downcast::() { - Ok(date_as_datetime(date)?) 
- } else { - Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) + ) -> ValResult> { + if let Ok(dt) = self.downcast_exact::() { + return Ok(ValidationMatch::exact(dt.into())); + } else if let Ok(dt) = self.downcast::() { + return Ok(ValidationMatch::strict(dt.into())); + } + + 'lax: { + if !strict { + return if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + bytes_as_datetime(self, str.as_bytes(), microseconds_overflow_behavior) + } else if let Ok(py_bytes) = self.downcast::() { + bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior) + } else if PyBool::is_exact_type_of(self) { + Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) + } else if let Ok(int) = extract_i64(self) { + int_as_datetime(self, int, 0) + } else if let Ok(float) = self.extract::() { + float_as_datetime(self, float) + } else if let Ok(date) = self.downcast::() { + Ok(date_as_datetime(date)?) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } - } - fn strict_timedelta( - &self, - _microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { - if let Ok(either_dt) = EitherTimedelta::try_from(self) { - Ok(either_dt) - } else { - Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) - } + Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) } - fn lax_timedelta( + fn validate_timedelta( &self, + strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { if let Ok(either_dt) = EitherTimedelta::try_from(self) { - Ok(either_dt) - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(py_bytes) = self.downcast::() { - bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(int) = extract_i64(self) { - Ok(int_as_duration(self, int)?.into()) - } else if let Ok(float) 
= self.extract::() { - Ok(float_as_duration(self, float)?.into()) - } else { - Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) + let exactness = if matches!(either_dt, EitherTimedelta::PyExact(_)) { + Exactness::Exact + } else { + Exactness::Strict + }; + return Ok(ValidationMatch::new(either_dt, exactness)); + } + + 'lax: { + if !strict { + return if let Ok(py_str) = self.downcast::() { + let str = py_string_str(py_str)?; + bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior) + } else if let Ok(py_bytes) = self.downcast::() { + bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior) + } else if let Ok(int) = extract_i64(self) { + Ok(int_as_duration(self, int)?.into()) + } else if let Ok(float) = self.extract::() { + Ok(float_as_duration(self, float)?.into()) + } else { + break 'lax; + } + .map(ValidationMatch::lax); + } } + + Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)) } } diff --git a/src/input/input_string.rs b/src/input/input_string.rs index b84908edf..e27ef6461 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -15,7 +15,7 @@ use super::datetime::{ use super::shared::{map_json_err, str_as_bool, str_as_float}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, - GenericIterator, GenericMapping, Input, + GenericIterator, GenericMapping, Input, ValidationMatch, }; #[derive(Debug)] @@ -96,54 +96,44 @@ impl<'a> Input<'a> for StringMapping<'a> { } } - fn strict_str(&'a self) -> ValResult> { + fn validate_str( + &'a self, + _strict: bool, + _coerce_numbers_to_str: bool, + ) -> ValResult>> { match self { - Self::String(s) => Ok((*s).into()), + Self::String(s) => Ok(ValidationMatch::strict((*s).into())), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::StringType, self)), } } - fn strict_bytes(&'a self) -> ValResult> { + fn validate_bytes(&'a self, _strict: bool) -> ValResult>> { match self { - 
Self::String(s) => py_string_str(s).map(|b| b.as_bytes().into()), + Self::String(s) => py_string_str(s).map(|b| ValidationMatch::strict(b.as_bytes().into())), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), } } - fn lax_bytes(&'a self) -> ValResult> { + fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch> { match self { - Self::String(s) => { - let str = py_string_str(s)?; - Ok(str.as_bytes().into()) - } - Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)), - } - } - - fn strict_bool(&self) -> ValResult { - match self { - Self::String(s) => str_as_bool(self, py_string_str(s)?), + Self::String(s) => str_as_bool(self, py_string_str(s)?).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), } } - fn strict_int(&'a self) -> ValResult> { + fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { match self { Self::String(s) => match py_string_str(s)?.parse() { - Ok(i) => Ok(EitherInt::I64(i)), + Ok(i) => Ok(ValidationMatch::strict(EitherInt::I64(i))), Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), }, Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } - fn ultra_strict_float(&'a self) -> ValResult> { - self.strict_float() - } - - fn strict_float(&'a self) -> ValResult> { + fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { match self { - Self::String(s) => str_as_float(self, py_string_str(s)?), + Self::String(s) => str_as_float(self, py_string_str(s)?).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), } } @@ -186,39 +176,45 @@ impl<'a> Input<'a> for StringMapping<'a> { Err(ValError::new(ErrorTypeDefaults::IterableType, self)) } - fn strict_date(&self) -> ValResult { + fn validate_date(&self, _strict: bool) -> ValResult> { match self { - Self::String(s) => bytes_as_date(self, 
py_string_str(s)?.as_bytes()), + Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes()).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DateType, self)), } } - fn strict_time( + fn validate_time( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { match self { - Self::String(s) => bytes_as_time(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::String(s) => bytes_as_time(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior) + .map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeType, self)), } } - fn strict_datetime( + fn validate_datetime( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { match self { - Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior) + .map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)), } } - fn strict_timedelta( + fn validate_timedelta( &self, + _strict: bool, microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, - ) -> ValResult { + ) -> ValResult> { match self { - Self::String(s) => bytes_as_timedelta(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior), + Self::String(s) => bytes_as_timedelta(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior) + .map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)), } } diff --git a/src/input/mod.rs b/src/input/mod.rs index 13c835f83..d7ca0a5bf 100644 --- a/src/input/mod.rs +++ b/src/input/mod.rs @@ -20,7 +20,7 @@ pub(crate) use input_string::StringMapping; pub(crate) use return_enums::{ 
py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator, - MappingGenericIterator, PyArgs, StringMappingGenericIterator, + MappingGenericIterator, PyArgs, StringMappingGenericIterator, ValidationMatch, }; // Defined here as it's not exported by pyo3 diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 412842a13..56c7098df 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -24,11 +24,44 @@ use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult}; use crate::tools::py_err; -use crate::validators::{CombinedValidator, ValidationState, Validator}; +use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator}; use super::input_string::StringMapping; use super::{py_error_on_minusone, Input}; +pub struct ValidationMatch(T, Exactness); + +impl ValidationMatch { + pub const fn new(value: T, exactness: Exactness) -> Self { + Self(value, exactness) + } + + pub const fn exact(value: T) -> Self { + Self(value, Exactness::Exact) + } + + pub const fn strict(value: T) -> Self { + Self(value, Exactness::Strict) + } + + pub const fn lax(value: T) -> Self { + Self(value, Exactness::Lax) + } + + pub fn require_exact(self) -> Option { + (self.1 == Exactness::Exact).then_some(self.0) + } + + pub fn unpack(self, state: &mut ValidationState) -> T { + state.floor_exactness(self.1); + self.0 + } + + pub fn into_inner(self) -> T { + self.0 + } +} + /// Container for all the collections (sized iterable containers) types, which /// can mostly be converted to each other in lax mode. /// This mostly matches python's definition of `Collection`. 
diff --git a/src/validators/any.rs b/src/validators/any.rs index 625eb4adf..2fad89091 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -4,7 +4,9 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::Input; -use super::{validation_state::ValidationState, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{ + validation_state::Exactness, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator, +}; /// This might seem useless, but it's useful in DictValidator to avoid Option a lot #[derive(Debug, Clone)] @@ -29,20 +31,14 @@ impl Validator for AnyValidator { &self, py: Python<'data>, input: &'data impl Input<'data>, - _state: &mut ValidationState, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + // in a union, Any should be preferred to doing lax coercions + state.floor_exactness(Exactness::Strict); Ok(input.to_object(py)) } - fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { - false - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 748b13338..7ae65d579 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -7,7 +7,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::schema_or_config_same; use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; -use crate::input::{GenericArguments, Input}; +use crate::input::{GenericArguments, Input, ValidationMatch}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; @@ -282,7 +282,7 @@ impl Validator for ArgumentsValidator { if let Some(kwargs) = $args.kwargs { if kwargs.len() > used_kwargs.len() { for (raw_key, value) in kwargs.iter() { - let either_str = match raw_key.strict_str() { + let either_str = match raw_key.validate_str(true, 
false).map(ValidationMatch::into_inner) { Ok(k) => k, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -332,26 +332,7 @@ impl Validator for ArgumentsValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.parameters - .iter() - .any(|p| p.validator.different_strict_behavior(ultra_strict)) - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - self.parameters - .iter() - .try_for_each(|parameter| parameter.validator.complete())?; - if let Some(v) = &self.var_args_validator { - v.complete()?; - } - if let Some(v) = &self.var_kwargs_validator { - v.complete()?; - }; - Ok(()) - } } diff --git a/src/validators/bool.rs b/src/validators/bool.rs index 3a38cf3e5..bcd48e991 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -38,19 +38,12 @@ impl Validator for BoolValidator { ) -> ValResult<'data, PyObject> { // TODO in theory this could be quicker if we used PyBool rather than going to a bool // and back again, might be worth profiling? 
- let strict = state.strict_or(self.strict); - Ok(input.validate_bool(strict)?.into_py(py)) - } - - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict + input + .validate_bool(state.strict_or(self.strict)) + .map(|val_match| val_match.unpack(state).into_py(py)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index fb90187ac..78a8acb24 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -46,21 +46,14 @@ impl Validator for BytesValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_bytes = input.validate_bytes(state.strict_or(self.strict))?; - Ok(either_bytes.into_py(py)) - } - - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict + input + .validate_bytes(state.strict_or(self.strict)) + .map(|m| m.unpack(state).into_py(py)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } #[derive(Debug, Clone)] @@ -79,7 +72,7 @@ impl Validator for BytesConstrainedValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_bytes = input.validate_bytes(state.strict_or(self.strict))?; + let either_bytes = input.validate_bytes(state.strict_or(self.strict))?.unpack(state); let len = either_bytes.len()?; if let Some(min_length) = self.min_length { @@ -104,21 +97,12 @@ impl Validator for BytesConstrainedValidator { )); } } - Ok(either_bytes.into_py(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { "constrained-bytes" } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } impl BytesConstrainedValidator { diff --git a/src/validators/call.rs b/src/validators/call.rs index eca1f0206..e0649aa53 100644 --- a/src/validators/call.rs +++ 
b/src/validators/call.rs @@ -98,24 +98,7 @@ impl Validator for CallValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if let Some(return_validator) = &self.return_validator { - if return_validator.different_strict_behavior(ultra_strict) { - return true; - } - } - self.arguments_validator.different_strict_behavior(ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.arguments_validator.complete()?; - match &self.return_validator { - Some(v) => v.complete(), - None => Ok(()), - } - } } diff --git a/src/validators/callable.rs b/src/validators/callable.rs index 83eb37cbe..3075e182e 100644 --- a/src/validators/callable.rs +++ b/src/validators/callable.rs @@ -4,6 +4,7 @@ use pyo3::types::PyDict; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; +use super::validation_state::Exactness; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] @@ -28,23 +29,16 @@ impl Validator for CallableValidator { &self, py: Python<'data>, input: &'data impl Input<'data>, - _state: &mut ValidationState, + state: &mut ValidationState, ) -> ValResult<'data, PyObject> { + state.floor_exactness(Exactness::Lax); match input.callable() { true => Ok(input.to_object(py)), false => Err(ValError::new(ErrorTypeDefaults::CallableType, input)), } } - fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { - false - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/chain.rs b/src/validators/chain.rs index c0f356fa0..d8da86e30 100644 --- a/src/validators/chain.rs +++ b/src/validators/chain.rs @@ -83,15 +83,7 @@ impl Validator for ChainValidator { steps_iter.try_fold(value, |v, step| step.validate(py, v.into_ref(py), state)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - 
self.steps.iter().any(|v| v.different_strict_behavior(ultra_strict)) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.steps.iter().try_for_each(CombinedValidator::complete) - } } diff --git a/src/validators/custom_error.rs b/src/validators/custom_error.rs index 0d9931c62..4ea31aa5a 100644 --- a/src/validators/custom_error.rs +++ b/src/validators/custom_error.rs @@ -99,15 +99,7 @@ impl Validator for CustomErrorValidator { .map_err(|_| self.custom_error.as_val_error(input)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.validator.different_strict_behavior(ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.validator.complete() - } } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index d93441ce0..1646f5ea6 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -8,13 +8,14 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; -use crate::input::{BorrowInput, GenericArguments, Input}; +use crate::input::{BorrowInput, GenericArguments, Input, ValidationMatch}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; use crate::validators::function::convert_err; use super::arguments::{json_get, json_slice, py_get, py_slice}; use super::model::{create_class, force_setattr, Revalidate}; +use super::validation_state::Exactness; use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; @@ -281,7 +282,7 @@ impl Validator for DataclassArgsValidator { if let Some(kwargs) = $args.kwargs { if kwargs.len() != used_keys.len() { for (raw_key, value) in kwargs.iter() { - match raw_key.strict_str() { + match raw_key.validate_str(true, 
false).map(ValidationMatch::into_inner) { Ok(either_str) => { if !used_keys.contains(either_str.as_cow()?.as_ref()) { // Unknown / extra field @@ -438,19 +439,9 @@ impl Validator for DataclassArgsValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.fields - .iter() - .any(|f| f.validator.different_strict_behavior(ultra_strict)) - } - fn get_name(&self) -> &str { &self.validator_name } - - fn complete(&self) -> PyResult<()> { - self.fields.iter().try_for_each(|field| field.validator.complete()) - } } #[derive(Debug)] @@ -554,6 +545,7 @@ impl Validator for DataclassValidator { )) } else { let val_output = self.validator.validate(py, input, state)?; + state.floor_exactness(Exactness::Strict); let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; Ok(dc) @@ -594,21 +586,9 @@ impl Validator for DataclassValidator { Ok(obj.to_object(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - self.validator.different_strict_behavior(ultra_strict) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.validator.complete() - } } impl DataclassValidator { diff --git a/src/validators/date.rs b/src/validators/date.rs index 3549f66f0..7c79101f4 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -11,6 +11,7 @@ use crate::input::{EitherDate, Input}; use crate::tools::SchemaDict; use crate::validators::datetime::{NowConstraint, NowOp}; +use super::Exactness; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] @@ -46,9 +47,12 @@ impl Validator for DateValidator { ) -> ValResult<'data, PyObject> { let strict = state.strict_or(self.strict); let date = match input.validate_date(strict) { - Ok(date) => date, + Ok(val_match) => val_match.unpack(state), // if the error was a parsing error, in lax mode we allow 
datetimes at midnight - Err(line_errors @ ValError::LineErrors(..)) if !strict => date_from_datetime(input)?.ok_or(line_errors)?, + Err(line_errors @ ValError::LineErrors(..)) if !strict => { + state.floor_exactness(Exactness::Lax); + date_from_datetime(input)?.ok_or(line_errors)? + } Err(otherwise) => return Err(otherwise), }; if let Some(constraints) = &self.constraints { @@ -96,17 +100,9 @@ impl Validator for DateValidator { Ok(date.try_into_py(py)?) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } /// In lax mode, if the input is not a date, we try parsing the input as a datetime, then check it is an @@ -115,7 +111,7 @@ impl Validator for DateValidator { /// Ok(None) means that this is not relevant to dates (the input was not a datetime nor a string) fn date_from_datetime<'data>(input: &'data impl Input<'data>) -> Result<Option<EitherDate<'data>>, ValError<'data>> { let either_dt = match input.validate_datetime(false, speedate::MicrosecondsPrecisionOverflowBehavior::Truncate) { - Ok(dt) => dt, + Ok(val_match) => val_match.into_inner(), // if the error was a parsing error, update the error type from DatetimeParsing to DateFromDatetimeParsing // and return it Err(ValError::LineErrors(mut line_errors)) => { diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 5f1fc8bef..edbd399e7 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -65,7 +65,9 @@ impl Validator for DateTimeValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let strict = state.strict_or(self.strict); - let datetime = input.validate_datetime(strict, self.microseconds_precision)?; + let datetime = input + .validate_datetime(strict, self.microseconds_precision)?
+ .unpack(state); if let Some(constraints) = &self.constraints { // if we get an error from as_speedate, it's probably because the input datetime was invalid // specifically had an invalid tzinfo, hence here we return a validation error @@ -125,17 +127,9 @@ impl Validator for DateTimeValidator { Ok(datetime.try_into_py(py)?) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } #[derive(Debug, Clone)] diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index eb3141c31..b9435f046 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -264,17 +264,9 @@ impl Validator for DecimalValidator { Ok(decimal.into()) } - fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { - true - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } pub(crate) fn create_decimal<'a>( diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 16aea8cd4..979278bb9 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -1,7 +1,3 @@ -use std::cell::RefCell; - -use ahash::HashSet; -use ahash::HashSetExt; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; @@ -123,32 +119,7 @@ impl Validator for DefinitionRefValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - thread_local! 
{ - static RECURSION_SET: RefCell<Option<HashSet<usize>>> = RefCell::new(None); - } - - let id = self as *const _ as usize; - // have to unwrap here, because we can't return an error from this function, should be okay - let validator: &CombinedValidator = self.definition.get().unwrap(); - if RECURSION_SET.with( - |set: &RefCell<Option<HashSet<usize>>>| { - set.borrow_mut().get_or_insert_with(HashSet::new).insert(id) - }, - ) { - let different_strict_behavior = validator.different_strict_behavior(ultra_strict); - RECURSION_SET.with(|set| set.borrow_mut().get_or_insert_with(HashSet::new).remove(&id)); - different_strict_behavior - } else { - false - } - } - fn get_name(&self) -> &str { self.definition.get_or_init_name(|v| v.get_name().into()) } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/dict.rs b/src/validators/dict.rs index c7df345ed..3ac284b2a 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -80,6 +80,7 @@ impl Validator for DictValidator { self.validate_generic_mapping(py, input, DictGenericIterator::new(py_dict)?, state) } GenericMapping::PyMapping(mapping) => { + state.floor_exactness(super::Exactness::Lax); self.validate_generic_mapping(py, input, MappingGenericIterator::new(mapping)?, state) } GenericMapping::StringMapping(dict) => { @@ -92,22 +93,9 @@ impl Validator for DictValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - self.key_validator.different_strict_behavior(true) || self.value_validator.different_strict_behavior(true) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.key_validator.complete()?; - self.value_validator.complete() - } } impl DictValidator { diff --git a/src/validators/float.rs b/src/validators/float.rs index 646d8f4d8..b72ffafc0 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -70,25 +70,16 @@ impl Validator for FloatValidator { input: &'data impl Input<'data>, state: &mut
ValidationState, ) -> ValResult<'data, PyObject> { - let strict = state.strict_or(self.strict); - let either_float = input.validate_float(strict, state.extra().ultra_strict)?; + let either_float = input.validate_float(state.strict_or(self.strict))?.unpack(state); if !self.allow_inf_nan && !either_float.as_f64().is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); } Ok(either_float.into_py(py)) } - fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { - true - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } #[derive(Debug, Clone)] @@ -111,8 +102,7 @@ impl Validator for ConstrainedFloatValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let strict = state.strict_or(self.strict); - let either_float = input.validate_float(strict, state.extra().ultra_strict)?; + let either_float = input.validate_float(state.strict_or(self.strict))?.unpack(state); let float: f64 = either_float.as_f64(); if !self.allow_inf_nan && !float.is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); @@ -177,17 +167,9 @@ impl Validator for ConstrainedFloatValidator { Ok(either_float.into_py(py)) } - fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { - true - } - fn get_name(&self) -> &str { "constrained-float" } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } impl BuildValidator for ConstrainedFloatValidator { diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index 4b4cdcb6f..190a8672d 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -2,8 +2,9 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyFrozenSet}; use crate::errors::ValResult; -use crate::input::Input; +use crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; +use crate::validators::Exactness; use super::list::min_length_check; use super::set::set_build; @@ -34,6 
+35,12 @@ impl Validator for FrozenSetValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let collection = input.validate_frozenset(state.strict_or(self.strict))?; + let exactness = match &collection { + GenericIterable::FrozenSet(_) => Exactness::Exact, + GenericIterable::Set(_) | GenericIterable::JsonArray(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); let f_set = PyFrozenSet::empty(py)?; collection.validate_to_set( py, @@ -48,19 +55,7 @@ impl Validator for FrozenSetValidator { Ok(f_set.into_py(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - self.item_validator.different_strict_behavior(true) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.item_validator.complete() - } } diff --git a/src/validators/function.rs b/src/validators/function.rs index 66bbafbb9..4c5ad9c29 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -113,21 +113,9 @@ macro_rules! impl_validator { self._validate(validate, py, obj, state) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - self.validator.different_strict_behavior(ultra_strict) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.validator.complete() - } } }; } @@ -252,18 +240,9 @@ impl Validator for FunctionPlainValidator { r.map_err(|e| convert_err(py, e, input)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - // best guess, should we change this? 
- !ultra_strict - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } #[derive(Debug)] @@ -349,12 +328,10 @@ impl Validator for FunctionWrapValidator { self.validation_error_cause, ), }; - self._validate( - Py::new(py, handler)?.into_ref(py), - py, - input.to_object(py).into_ref(py), - state, - ) + let handler = Py::new(py, handler)?.into_ref(py); + let result = self._validate(handler, py, input.to_object(py).into_ref(py), state); + state.exactness = handler.borrow_mut().validator.exactness; + result } fn validate_assignment<'data>( @@ -380,21 +357,9 @@ impl Validator for FunctionWrapValidator { self._validate(Py::new(py, handler)?.into_ref(py), py, obj, state) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - self.validator.different_strict_behavior(ultra_strict) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.validator.complete() - } } #[pyclass(module = "pydantic_core._pydantic_core")] diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 111b5c101..94497d228 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -11,7 +11,9 @@ use crate::tools::SchemaDict; use crate::ValidationError; use super::list::get_items_schema; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, InputType, ValidationState, Validator}; +use super::{ + BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, Extra, InputType, ValidationState, Validator, +}; #[derive(Debug, Clone)] pub struct GeneratorValidator { @@ -86,24 +88,9 @@ impl Validator for GeneratorValidator { Ok(v_iterator.into_py(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if let Some(ref v) = self.item_validator { - v.different_strict_behavior(ultra_strict) - } else { - false - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> 
PyResult<()> { - match &self.item_validator { - Some(v) => v.complete(), - None => Ok(()), - } - } } #[pyclass(module = "pydantic_core._pydantic_core")] @@ -226,6 +213,7 @@ pub struct InternalValidator { context: Option, self_instance: Option, recursion_guard: RecursionGuard, + pub(crate) exactness: Option, validation_mode: InputType, hide_input_in_errors: bool, validation_error_cause: bool, @@ -256,6 +244,7 @@ impl InternalValidator { context: extra.context.map(|d| d.into_py(py)), self_instance: extra.self_instance.map(|d| d.into_py(py)), recursion_guard: state.recursion_guard.clone(), + exactness: state.exactness, validation_mode: extra.input_type, hide_input_in_errors, validation_error_cause, @@ -274,13 +263,14 @@ impl InternalValidator { input_type: self.validation_mode, data: self.data.as_ref().map(|data| data.as_ref(py)), strict: self.strict, - ultra_strict: false, from_attributes: self.from_attributes, context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; let mut state = ValidationState::new(extra, &mut self.recursion_guard); - self.validator + state.exactness = self.exactness; + let result = self + .validator .validate_assignment(py, model, field_name, field_value, &mut state) .map_err(|e| { ValidationError::from_val_error( @@ -292,7 +282,9 @@ impl InternalValidator { self.hide_input_in_errors, self.validation_error_cause, ) - }) + }); + self.exactness = state.exactness; + result } pub fn validate<'data>( @@ -305,13 +297,13 @@ impl InternalValidator { input_type: self.validation_mode, data: self.data.as_ref().map(|data| data.as_ref(py)), strict: self.strict, - ultra_strict: false, from_attributes: self.from_attributes, context: self.context.as_ref().map(|data| data.as_ref(py)), self_instance: self.self_instance.as_ref().map(|data| data.as_ref(py)), }; let mut state = ValidationState::new(extra, &mut self.recursion_guard); - self.validator.validate(py, input, &mut 
state).map_err(|e| { + state.exactness = self.exactness; + let result = self.validator.validate(py, input, &mut state).map_err(|e| { ValidationError::from_val_error( py, self.name.to_object(py), @@ -321,7 +313,9 @@ impl InternalValidator { self.hide_input_in_errors, self.validation_error_cause, ) - }) + }); + self.exactness = state.exactness; + result } } diff --git a/src/validators/int.rs b/src/validators/int.rs index 1a807d1ef..dabfb5115 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -8,8 +8,7 @@ use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::{Input, Int}; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct IntValidator { @@ -50,21 +49,14 @@ impl Validator for IntValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_int = input.validate_int(state.strict_or(self.strict))?; - Ok(either_int.into_py(py)) - } - - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict + input + .validate_int(state.strict_or(self.strict)) + .map(|val_match| val_match.unpack(state).into_py(py)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } #[derive(Debug, Clone)] @@ -86,7 +78,7 @@ impl Validator for ConstrainedIntValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_int = input.validate_int(state.strict_or(self.strict))?; + let either_int = input.validate_int(state.strict_or(self.strict))?.unpack(state); let int_value = either_int.as_int()?; if let Some(ref multiple_of) = self.multiple_of { @@ -147,17 +139,9 @@ impl Validator for ConstrainedIntValidator { Ok(either_int.into_py(py)) } - fn 
different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { "constrained-int" } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } impl ConstrainedIntValidator { diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index e64d0717c..189589d6a 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -8,8 +8,7 @@ use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct IsInstanceValidator { @@ -83,15 +82,7 @@ impl Validator for IsInstanceValidator { } } - fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { - false - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/is_subclass.rs b/src/validators/is_subclass.rs index 0866fa1e7..7a89ef36c 100644 --- a/src/validators/is_subclass.rs +++ b/src/validators/is_subclass.rs @@ -6,8 +6,7 @@ use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] pub struct IsSubclassValidator { @@ -62,15 +61,7 @@ impl Validator for IsSubclassValidator { } } - fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { - false - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/json.rs b/src/validators/json.rs index fd832f874..9dfb5fae2 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs 
@@ -6,8 +6,7 @@ use crate::errors::ValResult; use crate::input::Input; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug)] pub struct JsonValidator { @@ -61,22 +60,7 @@ impl Validator for JsonValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if let Some(ref v) = self.validator { - v.different_strict_behavior(ultra_strict) - } else { - false - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - match &self.validator { - Some(v) => v.complete(), - None => Ok(()), - } - } } diff --git a/src/validators/json_or_python.rs b/src/validators/json_or_python.rs index cd952bed1..302cbdaf6 100644 --- a/src/validators/json_or_python.rs +++ b/src/validators/json_or_python.rs @@ -7,9 +7,7 @@ use crate::errors::ValResult; use crate::input::Input; use crate::tools::SchemaDict; -use super::InputType; -use super::ValidationState; -use super::{build_validator, BuildValidator, CombinedValidator, Validator}; +use super::{build_validator, BuildValidator, CombinedValidator, InputType, ValidationState, Validator}; #[derive(Debug)] pub struct JsonOrPython { @@ -63,16 +61,7 @@ impl Validator for JsonOrPython { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.json.different_strict_behavior(ultra_strict) || self.python.different_strict_behavior(ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.json.complete()?; - self.python.complete() - } } diff --git a/src/validators/lax_or_strict.rs b/src/validators/lax_or_strict.rs index b5cec61be..78021cd4c 100644 --- a/src/validators/lax_or_strict.rs +++ b/src/validators/lax_or_strict.rs @@ -7,6 +7,7 @@ use crate::errors::ValResult; use crate::input::Input; use 
crate::tools::SchemaDict; +use super::Exactness; use super::ValidationState; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; @@ -64,24 +65,20 @@ impl Validator for LaxOrStrictValidator { if state.strict_or(self.strict) { self.strict_validator.validate(py, input, state) } else { + // horrible edge case: if doing smart union validation, we need to try the strict validator + // anyway and prefer that if it succeeds + if state.exactness.is_some() { + if let Ok(strict_result) = self.strict_validator.validate(py, input, state) { + return Ok(strict_result); + } + // this is now known to be not strict + state.floor_exactness(Exactness::Lax); + } self.lax_validator.validate(py, input, state) } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - self.strict_validator.different_strict_behavior(true) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.lax_validator.complete()?; - self.strict_validator.complete() - } } diff --git a/src/validators/list.rs b/src/validators/list.rs index c0af5fcb5..b2e0ff116 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -6,6 +6,7 @@ use pyo3::types::PyDict; use crate::errors::ValResult; use crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; +use crate::validators::Exactness; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -122,6 +123,12 @@ impl Validator for ListValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let seq = input.validate_list(state.strict_or(self.strict))?; + let exactness = match &seq { + GenericIterable::List(_) | GenericIterable::JsonArray(_) => Exactness::Exact, + GenericIterable::Tuple(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); let output = match self.item_validator { Some(ref v) => 
seq.validate_to_vec(py, input, self.max_length, "List", v, state)?, @@ -138,17 +145,6 @@ impl Validator for ListValidator { Ok(output.into_py(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - match self.item_validator { - Some(ref v) => v.different_strict_behavior(true), - None => false, - } - } else { - true - } - } - fn get_name(&self) -> &str { // The logic here is a little janky, it's done to try to cache the formatted name // while also trying to render definitions correctly when possible. @@ -167,11 +163,4 @@ impl Validator for ListValidator { } } } - - fn complete(&self) -> PyResult<()> { - if let Some(v) = &self.item_validator { - v.complete()?; - } - Ok(()) - } } diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 19bbe91f5..c9a846695 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -48,8 +48,8 @@ impl LiteralLookup { for (k, v) in expected { let id = values.len(); values.push(v); - if let Ok(bool) = k.strict_bool() { - if bool { + if let Ok(bool) = k.validate_bool(true) { + if bool.into_inner() { expected_bool.true_id = Some(id); } else { expected_bool.false_id = Some(id); @@ -97,8 +97,8 @@ impl LiteralLookup { input: &'data I, ) -> ValResult<'data, Option<(&'data I, &T)>> { if let Some(expected_bool) = &self.expected_bool { - if let Ok(bool_value) = input.strict_bool() { - if bool_value { + if let Ok(bool_value) = input.validate_bool(true) { + if bool_value.into_inner() { if let Some(true_value) = &expected_bool.true_id { return Ok(Some((input, &self.values[*true_value]))); } @@ -198,17 +198,9 @@ impl Validator for LiteralValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } pub fn expected_repr_name(mut repr_args: Vec, base_name: &'static str) -> (String, String) { diff --git a/src/validators/mod.rs 
b/src/validators/mod.rs index 2a4cbf165..f541ea45d 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -58,10 +58,9 @@ mod uuid; mod validation_state; mod with_default; +pub use self::validation_state::{Exactness, ValidationState}; pub use with_default::DefaultType; -pub use self::validation_state::ValidationState; - #[pyclass(module = "pydantic_core._pydantic_core", name = "Some")] pub struct PySome { #[pyo3(get)] @@ -120,10 +119,6 @@ impl SchemaValidator { let validator = build_validator(schema, config, &mut definitions_builder)?; let definitions = definitions_builder.finish()?; - validator.complete()?; - for val in definitions.values() { - val.get().unwrap().complete()?; - } let py_schema = schema.into_py(py); let py_config = match config { Some(c) if !c.is_empty() => Some(c.into_py(py)), @@ -271,7 +266,6 @@ impl SchemaValidator { data: None, strict, from_attributes, - ultra_strict: false, context, self_instance: None, }; @@ -290,7 +284,6 @@ impl SchemaValidator { data: None, strict, from_attributes: None, - ultra_strict: false, context, self_instance: None, }; @@ -402,10 +395,6 @@ impl<'py> SelfValidator<'py> { Err(err) => return py_schema_err!("Error building self-schema:\n {}", err), }; let definitions = definitions_builder.finish()?; - validator.complete()?; - for val in definitions.values() { - val.get().unwrap().complete()?; - } Ok(SchemaValidator { validator, definitions, @@ -569,8 +558,6 @@ pub struct Extra<'a> { pub data: Option<&'a PyDict>, /// whether we're in strict or lax mode pub strict: Option<bool>, - /// whether we're in ultra-strict mode, only used occasionally in unions - pub ultra_strict: bool, /// Validation time setting of `from_attributes` pub from_attributes: Option<bool>, /// context used in validator functions @@ -591,7 +578,6 @@ impl<'a> Extra<'a> { input_type, data: None, strict, - ultra_strict: false, from_attributes, context, self_instance, @@ -600,12 +586,11 @@ impl<'a> Extra<'a> { } impl<'a> Extra<'a> { - pub fn as_strict(&self,
ultra_strict: bool) -> Self { + pub fn as_strict(&self) -> Self { Self { input_type: self.input_type, data: self.data, strict: Some(true), - ultra_strict, from_attributes: self.from_attributes, context: self.context, self_instance: self.self_instance, @@ -742,15 +727,7 @@ pub trait Validator: Send + Sync + Debug { Err(py_err.into()) } - /// whether the validator behaves differently in strict mode, and in ultra strict mode - /// implementations should return true if any of their sub-validators return true - fn different_strict_behavior(&self, ultra_strict: bool) -> bool; - /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator /// this is used in the error location in unions, and in the top level message in `ValidationError` fn get_name(&self) -> &str; - - /// this method must be implemented for any validator which holds references to other validators, - /// it is used by `UnionValidator` to calculate strictness - fn complete(&self) -> PyResult<()>; } diff --git a/src/validators/model.rs b/src/validators/model.rs index 1459f56f5..0299ce5d8 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -7,6 +7,7 @@ use pyo3::types::{PyDict, PySet, PyString, PyTuple, PyType}; use pyo3::{ffi, intern}; use super::function::convert_err; +use super::validation_state::Exactness; use super::{ build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; @@ -143,6 +144,8 @@ impl Validator for ModelValidator { Ok(input.to_object(py)) } } else { + // Having to construct a new model is not an exact match + state.floor_exactness(Exactness::Strict); self.validate_construct(py, input, None, state) } } @@ -206,21 +209,9 @@ impl Validator for ModelValidator { Ok(model.into_py(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - self.validator.different_strict_behavior(ultra_strict) - } else { - true - } - } - fn get_name(&self) -> &str { 
&self.name } - - fn complete(&self) -> PyResult<()> { - self.validator.complete() - } } impl ModelValidator { diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 17ec81670..a284bd4e9 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -10,13 +10,14 @@ use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, - MappingGenericIterator, StringMappingGenericIterator, + MappingGenericIterator, StringMappingGenericIterator, ValidationMatch, }; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; -use super::ValidationState; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, +}; use std::ops::ControlFlow; @@ -250,7 +251,7 @@ impl Validator for ModelFieldsValidator { let model_extra_dict = PyDict::new(py); for item_result in <$iter>::new($dict)? 
{ let (raw_key, value) = item_result?; - let either_str = match raw_key.strict_str() { + let either_str = match raw_key.validate_str(true, false).map(ValidationMatch::into_inner) { Ok(k) => k, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -433,21 +434,7 @@ impl Validator for ModelFieldsValidator { Ok((new_data.to_object(py), new_extra, fields_set.to_object(py)).to_object(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.fields - .iter() - .any(|f| f.validator.different_strict_behavior(ultra_strict)) - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - self.fields.iter().try_for_each(|f| f.validator.complete())?; - match &self.extras_validator { - Some(v) => v.complete(), - None => Ok(()), - } - } } diff --git a/src/validators/none.rs b/src/validators/none.rs index f241be9d8..f6891292b 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -36,15 +36,7 @@ impl Validator for NoneValidator { } } - fn different_strict_behavior(&self, _ultra_strict: bool) -> bool { - false - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 7f4cf19fc..85fbd6c26 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -45,15 +45,7 @@ impl Validator for NullableValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.validator.different_strict_behavior(ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.validator.complete() - } } diff --git a/src/validators/set.rs b/src/validators/set.rs index 9270ea204..d29c60c3f 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -2,8 +2,9 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PySet}; use crate::errors::ValResult; -use crate::input::Input; +use 
crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; +use crate::validators::Exactness; use super::list::min_length_check; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -64,25 +65,19 @@ impl Validator for SetValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let collection = input.validate_set(state.strict_or(self.strict))?; + let exactness = match &collection { + GenericIterable::Set(_) => Exactness::Exact, + GenericIterable::FrozenSet(_) | GenericIterable::JsonArray(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); let set = PySet::empty(py)?; collection.validate_to_set(py, set, input, self.max_length, "Set", &self.item_validator, state)?; min_length_check!(input, "Set", self.min_length, set); Ok(set.into_py(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - self.item_validator.different_strict_behavior(true) - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.item_validator.complete() - } } diff --git a/src/validators/string.rs b/src/validators/string.rs index 0be51ece8..98d8a9d99 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -47,21 +47,14 @@ impl Validator for StrValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)?; - Ok(either_str.into_py(py)) - } - - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict + input + .validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str) + .map(|val_match| val_match.unpack(state).into_py(py)) } fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } /// Any new properties set here must be reflected in 
`has_constraints_set` @@ -86,7 +79,9 @@ impl Validator for StrConstrainedValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let either_str = input.validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)?; + let either_str = input + .validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)? + .unpack(state); let cow = either_str.as_cow()?; let mut str = cow.as_ref(); if self.strip_whitespace { @@ -146,17 +141,9 @@ impl Validator for StrConstrainedValidator { Ok(py_string.into_py(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { "constrained-str" } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } impl StrConstrainedValidator { diff --git a/src/validators/time.rs b/src/validators/time.rs index f5e2be7c7..abf82091f 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -46,7 +46,9 @@ impl Validator for TimeValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let time = input.validate_time(state.strict_or(self.strict), self.microseconds_precision)?; + let time = input + .validate_time(state.strict_or(self.strict), self.microseconds_precision)? + .unpack(state); if let Some(constraints) = &self.constraints { let raw_time = time.as_raw()?; @@ -78,17 +80,9 @@ impl Validator for TimeValidator { Ok(time.try_into_py(py)?) 
} - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } fn convert_pytime(schema: &PyDict, field: &PyString) -> PyResult> { diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index a58d98f7a..f04fef91c 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -71,7 +71,9 @@ impl Validator for TimeDeltaValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - let timedelta = input.validate_timedelta(state.strict_or(self.strict), self.microseconds_precision)?; + let timedelta = input + .validate_timedelta(state.strict_or(self.strict), self.microseconds_precision)? + .unpack(state); let py_timedelta = timedelta.try_into_py(py)?; if let Some(constraints) = &self.constraints { let raw_timedelta = timedelta.to_duration()?; @@ -101,17 +103,9 @@ impl Validator for TimeDeltaValidator { Ok(py_timedelta.into()) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } fn pydelta_to_human_readable(py_delta: &PyDelta) -> String { let total_seconds = py_delta.get_seconds(); diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 1b7cf9b00..9513582e5 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -6,6 +6,7 @@ use crate::build_tools::is_strict; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; +use crate::validators::Exactness; use super::list::{get_items_schema, min_length_check}; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -51,6 +52,12 @@ impl Validator for TupleVariableValidator { state: &mut 
ValidationState, ) -> ValResult<'data, PyObject> { let seq = input.validate_tuple(state.strict_or(self.strict))?; + let exactness = match &seq { + GenericIterable::Tuple(_) | GenericIterable::JsonArray(_) => Exactness::Exact, + GenericIterable::List(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); let output = match self.item_validator { Some(ref v) => seq.validate_to_vec(py, input, self.max_length, "Tuple", v, state)?, @@ -60,27 +67,9 @@ impl Validator for TupleVariableValidator { Ok(PyTuple::new(py, &output).into_py(py)) } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - match self.item_validator { - Some(ref v) => v.different_strict_behavior(true), - None => false, - } - } else { - true - } - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - match &self.item_validator { - Some(v) => v.complete(), - None => Ok(()), - } - } } #[derive(Debug)] @@ -199,6 +188,13 @@ impl Validator for TuplePositionalValidator { state: &mut ValidationState, ) -> ValResult<'data, PyObject> { let collection = input.validate_tuple(state.strict_or(self.strict))?; + let exactness: crate::validators::Exactness = match &collection { + GenericIterable::Tuple(_) | GenericIterable::JsonArray(_) => Exactness::Exact, + GenericIterable::List(_) => Exactness::Strict, + _ => Exactness::Lax, + }; + state.floor_exactness(exactness); + let actual_length = collection.generic_len(); let expected_length = if self.extras_validator.is_some() { actual_length.unwrap_or(self.items_validators.len()) @@ -238,29 +234,7 @@ impl Validator for TuplePositionalValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - if ultra_strict { - if self.items_validators.iter().any(|v| v.different_strict_behavior(true)) { - true - } else if let Some(ref v) = self.extras_validator { - v.different_strict_behavior(true) - } else { - false - } - } else { - true - } - } - fn 
get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.items_validators.iter().try_for_each(CombinedValidator::complete)?; - match &self.extras_validator { - Some(v) => v.complete(), - None => Ok(()), - } - } } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 5839959e7..f55b7d717 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -11,7 +11,7 @@ use crate::build_tools::{is_strict, schema_or_config, schema_or_config_same, Ext use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, - MappingGenericIterator, StringMappingGenericIterator, + MappingGenericIterator, StringMappingGenericIterator, ValidationMatch, }; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; @@ -252,7 +252,7 @@ impl Validator for TypedDictValidator { if let Some(ref mut used_keys) = used_keys { for item_result in <$iter>::new($dict)? 
{ let (raw_key, value) = item_result?; - let either_str = match raw_key.strict_str() { + let either_str = match raw_key.validate_str(true, false).map(ValidationMatch::into_inner) { Ok(k) => k, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { @@ -327,21 +327,7 @@ impl Validator for TypedDictValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.fields - .iter() - .any(|f| f.validator.different_strict_behavior(ultra_strict)) - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - self.fields.iter().try_for_each(|f| f.validator.complete())?; - match &self.extras_validator { - Some(v) => v.complete(), - None => Ok(()), - } - } } diff --git a/src/validators/union.rs b/src/validators/union.rs index 837114408..0f8fded07 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -1,6 +1,5 @@ use std::fmt::Write; use std::str::FromStr; -use std::sync::atomic::{AtomicBool, Ordering}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyString, PyTuple}; @@ -17,48 +16,22 @@ use crate::tools::SchemaDict; use super::custom_error::CustomError; use super::literal::LiteralLookup; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator, +}; #[derive(Debug)] enum UnionMode { - Smart { - strict_required: AtomicBool, - ultra_strict_required: AtomicBool, - }, + Smart, LeftToRight, } -impl UnionMode { - // construct smart with some default values - const fn default_smart() -> Self { - Self::Smart { - strict_required: AtomicBool::new(true), - ultra_strict_required: AtomicBool::new(false), - } - } -} - -impl Clone for UnionMode { - fn clone(&self) -> Self { - match self { - Self::Smart { - strict_required, - ultra_strict_required, - } => Self::Smart { - strict_required: 
AtomicBool::new(strict_required.load(Ordering::SeqCst)), - ultra_strict_required: AtomicBool::new(ultra_strict_required.load(Ordering::SeqCst)), - }, - Self::LeftToRight => Self::LeftToRight, - } - } -} - impl FromStr for UnionMode { type Err = PyErr; fn from_str(s: &str) -> Result { match s { - "smart" => Ok(Self::default_smart()), + "smart" => Ok(Self::Smart), "left_to_right" => Ok(Self::LeftToRight), s => py_schema_err!("Invalid union mode: `{}`, expected `smart` or `left_to_right`", s), } @@ -103,7 +76,7 @@ impl BuildValidator for UnionValidator { let auto_collapse = || schema.get_as_req(intern!(py, "auto_collapse")).unwrap_or(true); let mode = schema .get_as::<&str>(intern!(py, "mode"))? - .map_or(Ok(UnionMode::default_smart()), UnionMode::from_str)?; + .map_or(Ok(UnionMode::Smart), UnionMode::from_str)?; match choices.len() { 0 => py_schema_err!("One or more union choices required"), 1 if auto_collapse() => Ok(choices.into_iter().next().unwrap().0), @@ -128,71 +101,74 @@ impl BuildValidator for UnionValidator { } impl UnionValidator { - fn validate_smart<'s, 'data>( - &'s self, + fn validate_smart<'data>( + &self, py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - strict_required: bool, - ultra_strict_required: bool, ) -> ValResult<'data, PyObject> { - if ultra_strict_required { - // do an ultra strict check first - let state = &mut state.rebind_extra(|extra| { - extra.strict = Some(true); - extra.ultra_strict = true; - }); - if let Some(res) = self - .choices - .iter() - .map(|(validator, _label)| validator.validate(py, input, state)) - .find(ValResult::is_ok) - { - return res; - } - } - + let old_exactness = state.exactness; + let strict = state.strict_or(self.strict); let mut errors = MaybeErrors::new(self.custom_error.as_ref()); - if state.strict_or(self.strict) { - let state = &mut state.rebind_extra(|extra| extra.strict = Some(true)); - for (validator, label) in &self.choices { - match validator.validate(py, input, state) 
{ - Err(ValError::LineErrors(lines)) => errors.push(validator, label.as_deref(), lines), - otherwise => return otherwise, - }; - } + let mut success = None; - Err(errors.into_val_error(input)) - } else { - if strict_required { - // 1st pass: check if the value is an exact instance of one of the Union types, - // e.g. use validate in strict mode - let state = &mut state.rebind_extra(|extra| extra.strict = Some(true)); - if let Some(res) = self - .choices - .iter() - .map(|(validator, _label)| validator.validate(py, input, state)) - .find(ValResult::is_ok) - { - return res; + for (choice, label) in &self.choices { + let state = &mut state.rebind_extra(|extra| { + if strict { + extra.strict = Some(strict); } + }); + state.exactness = Some(Exactness::Exact); + let result = choice.validate(py, input, state); + match result { + Ok(new_success) => match state.exactness { + // exact match, return + Some(Exactness::Exact) => { + return { + // exact match, return, restore any previous exactness + state.exactness = old_exactness; + Ok(new_success) + }; + } + _ => { + // success should always have an exactness + debug_assert_ne!(state.exactness, None); + let new_exactness = state.exactness.unwrap_or(Exactness::Lax); + // if the new result has higher exactness than the current success, replace it + if success + .as_ref() + .map_or(true, |(_, current_exactness)| *current_exactness < new_exactness) + { + // TODO: is there a possible optimization here, where once there has + // been one success, we turn on strict mode, to avoid unnecessary + // coercions for further validation? 
+ success = Some((new_success, new_exactness)); + } + } + }, + Err(ValError::LineErrors(lines)) => { + // if we don't yet know this validation will succeed, record the error + if success.is_none() { + errors.push(choice, label.as_deref(), lines); + } + } + otherwise => return otherwise, } + } + state.exactness = old_exactness; - // 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate - for (validator, label) in &self.choices { - match validator.validate(py, input, state) { - Err(ValError::LineErrors(lines)) => errors.push(validator, label.as_deref(), lines), - otherwise => return otherwise, - }; - } - - Err(errors.into_val_error(input)) + if let Some((success, exactness)) = success { + state.floor_exactness(exactness); + return Ok(success); } + + // no matches, build errors + Err(errors.into_val_error(input)) } - fn validate_left_to_right<'s, 'data>( - &'s self, + fn validate_left_to_right<'data>( + &self, py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, @@ -232,50 +208,15 @@ impl Validator for UnionValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult<'data, PyObject> { - match &self.mode { - UnionMode::Smart { - strict_required, - ultra_strict_required, - } => self.validate_smart( - py, - input, - state, - strict_required.load(Ordering::SeqCst), - ultra_strict_required.load(Ordering::SeqCst), - ), + match self.mode { + UnionMode::Smart => self.validate_smart(py, input, state), UnionMode::LeftToRight => self.validate_left_to_right(py, input, state), } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.choices - .iter() - .any(|(v, _)| v.different_strict_behavior(ultra_strict)) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.choices.iter().try_for_each(|(v, _)| v.complete())?; - if let UnionMode::Smart { - strict_required, - ultra_strict_required, - } = &self.mode - { - 
strict_required.store( - self.choices.iter().any(|(v, _)| v.different_strict_behavior(false)), - Ordering::SeqCst, - ); - ultra_strict_required.store( - self.choices.iter().any(|(v, _)| v.different_strict_behavior(true)), - Ordering::SeqCst, - ); - } - - Ok(()) - } } struct ChoiceLineErrors<'a, 'data> { @@ -494,20 +435,9 @@ impl Validator for TaggedUnionValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.lookup - .values - .iter() - .any(|v| v.different_strict_behavior(ultra_strict)) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.lookup.values.iter().try_for_each(CombinedValidator::complete) - } } impl TaggedUnionValidator { @@ -519,7 +449,7 @@ impl TaggedUnionValidator { let dict = input.strict_dict()?; let either_tag = match dict { GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "type"))? { - Some(t) => t.strict_str()?, + Some(t) => t.validate_str(true, false)?.into_inner(), None => return Err(self.tag_not_found(input)), }, _ => unreachable!(), @@ -530,7 +460,7 @@ impl TaggedUnionValidator { if tag == "function" || tag == "tuple" { let mode = match dict { GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "mode"))? 
{ - Some(m) => Some(m.strict_str()?), + Some(m) => Some(m.validate_str(true, false)?.into_inner()), None => None, }, _ => unreachable!(), diff --git a/src/validators/url.rs b/src/validators/url.rs index b5fe8bf4c..77f887b7d 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -16,6 +16,7 @@ use crate::tools::SchemaDict; use crate::url::{schema_is_special, PyMultiHostUrl, PyUrl}; use super::literal::expected_repr_name; +use super::Exactness; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; type AllowedSchemas = Option<(AHashSet, String)>; @@ -87,28 +88,25 @@ impl Validator for UrlValidator { self.default_port, &self.default_path, ) { - Ok(()) => Ok(PyUrl::new(lib_url).into_py(py)), + Ok(()) => { + // Lax rather than strict to preserve V2.4 semantic that str wins over url in union + state.floor_exactness(Exactness::Lax); + Ok(PyUrl::new(lib_url).into_py(py)) + } Err(error_type) => return Err(ValError::new(error_type, input)), } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } impl UrlValidator { fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, Url> { match input.validate_str(strict, false) { - Ok(either_str) => { + Ok(val_match) => { + let either_str = val_match.into_inner(); let cow = either_str.as_cow()?; let url_str = cow.as_ref(); @@ -223,28 +221,25 @@ impl Validator for MultiHostUrlValidator { self.default_port, &self.default_path, ) { - Ok(()) => Ok(multi_url.into_py(py)), + Ok(()) => { + // Lax rather than strict to preserve V2.4 semantic that str wins over url in union + state.floor_exactness(Exactness::Lax); + Ok(multi_url.into_py(py)) + } Err(error_type) => return Err(ValError::new(error_type, input)), } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn 
get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } impl MultiHostUrlValidator { fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, PyMultiHostUrl> { match input.validate_str(strict, false) { - Ok(either_str) => { + Ok(val_match) => { + let either_str = val_match.into_inner(); let cow = either_str.as_cow()?; let url_str = cow.as_ref(); diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index 94d302438..9e4ce9fb5 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -13,7 +13,7 @@ use crate::tools::SchemaDict; use super::model::create_class; use super::model::force_setattr; -use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator}; const UUID_INT: &str = "int"; const UUID_IS_SAFE: &str = "is_safe"; @@ -117,22 +117,19 @@ impl Validator for UuidValidator { input, )) } else { + // In python mode this is a coercion, in JSON mode we treat a UUID string as an + // exact match + if input.is_python() { + state.floor_exactness(Exactness::Lax); + } let uuid = self.get_uuid(input)?; self.create_py_uuid(py, class, &uuid) } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - !ultra_strict - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } - - fn complete(&self) -> PyResult<()> { - Ok(()) - } } impl UuidValidator { @@ -154,7 +151,8 @@ impl UuidValidator { None => { let either_bytes = input .validate_bytes(true) - .map_err(|_| ValError::new(ErrorTypeDefaults::UuidType, input))?; + .map_err(|_| ValError::new(ErrorTypeDefaults::UuidType, input))? 
+ .into_inner(); let bytes_slice = either_bytes.as_slice(); 'parse: { // Try parsing as utf8, but don't care if it fails diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index 79ec8b87a..aacd7d2af 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -2,15 +2,27 @@ use crate::recursion_guard::RecursionGuard; use super::Extra; +#[derive(Clone, Copy, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] +pub enum Exactness { + Lax, + Strict, + Exact, +} + pub struct ValidationState<'a> { pub recursion_guard: &'a mut RecursionGuard, + pub exactness: Option, // deliberately make Extra readonly extra: Extra<'a>, } impl<'a> ValidationState<'a> { pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionGuard) -> Self { - Self { recursion_guard, extra } + Self { + recursion_guard, // Don't care about exactness unless doing union validation + exactness: None, + extra, + } } pub fn with_new_extra<'r, R: 'r>( @@ -22,9 +34,15 @@ impl<'a> ValidationState<'a> { // but lifetimes get in a tangle. Maybe someone brave wants to have a go at unpicking lifetimes. let mut new_state = ValidationState { recursion_guard: self.recursion_guard, + exactness: self.exactness, extra, }; - f(&mut new_state) + let result = f(&mut new_state); + match new_state.exactness { + Some(exactness) => self.floor_exactness(exactness), + None => self.exactness = None, + } + result } /// Temporarily rebinds the extra field by calling `f` to modify extra. @@ -47,6 +65,23 @@ impl<'a> ValidationState<'a> { pub fn strict_or(&self, default: bool) -> bool { self.extra.strict.unwrap_or(default) } + + /// Sets the exactness to the lower of the current exactness + /// and the given exactness. + /// + /// This is designed to be used in union validation, where the + /// idea is that the "most exact" validation wins. 
+ pub fn floor_exactness(&mut self, exactness: Exactness) { + match self.exactness { + None | Some(Exactness::Lax) => {} + Some(Exactness::Strict) => { + if exactness == Exactness::Lax { + self.exactness = Some(Exactness::Lax); + } + } + Some(Exactness::Exact) => self.exactness = Some(exactness), + } + } } pub struct ValidationStateWithReboundExtra<'state, 'a> { diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index 36b275dd1..a06ccd0cd 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -182,17 +182,9 @@ impl Validator for WithDefaultValidator { } } - fn different_strict_behavior(&self, ultra_strict: bool) -> bool { - self.validator.different_strict_behavior(ultra_strict) - } - fn get_name(&self) -> &str { &self.name } - - fn complete(&self) -> PyResult<()> { - self.validator.complete() - } } impl WithDefaultValidator { diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index 2d676ac17..9f7a93d57 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -1085,3 +1085,22 @@ def test_complex_recursive_type() -> None: 'input': {'a': datetime.date(1992, 12, 11)}, }, ] + + +def test_no_exponential_blowup(): + """See https://github.com/pydantic/pydantic/issues/8049 + + There was a performance bug which led to exponential blowup when trying to + build a schema with many intermingled recursive definitions. 
+ """ + unions = core_schema.union_schema([core_schema.definition_reference_schema(f'foo_{i}') for i in range(100)]) + + schema = core_schema.definitions_schema( + core_schema.typed_dict_schema({'x': core_schema.typed_dict_field(unions)}), + definitions=[ + core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(unions)}, ref=f'foo_{i}') + for i in range(100) + ], + ) + + SchemaValidator(schema) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index ad51fb447..503a5f387 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1,3 +1,9 @@ +from dataclasses import dataclass +from datetime import date, time +from enum import Enum, IntEnum +from typing import Any +from uuid import UUID + import pytest from dirty_equals import IsFloat, IsInt @@ -342,9 +348,6 @@ def test_dirty_behaviour(): def test_int_float(): v = SchemaValidator(core_schema.union_schema([core_schema.int_schema(), core_schema.float_schema()])) - assert 'strict_required:true' in plain_repr(v) - assert 'ultra_strict_required:true' in plain_repr(v) # since "float" schema has ultra-strict behaviour - assert v.validate_python(1) == IsInt(approx=1, delta=0) assert v.validate_json('1') == IsInt(approx=1, delta=0) assert v.validate_python(1.0) == IsFloat(approx=1, delta=0) @@ -382,17 +385,8 @@ def test_str_float(): assert v.validate_json('"1"') == '1' -def test_strict_check(): - v = SchemaValidator(core_schema.union_schema([core_schema.int_schema(), core_schema.json_schema()])) - assert 'strict_required:true' in plain_repr(v) - assert 'ultra_strict_required:false' in plain_repr(v) - - def test_no_strict_check(): v = SchemaValidator(core_schema.union_schema([core_schema.is_instance_schema(int), core_schema.json_schema()])) - assert 'strict_required:false' in plain_repr(v) - assert 'ultra_strict_required:false' in plain_repr(v) - assert v.validate_python(123) == 123 assert v.validate_python('[1, 2, 3]') == [1, 2, 3] @@ -414,8 +408,6 @@ def 
test_strict_reference(): ], ) ) - assert 'strict_required:true' in plain_repr(v) - assert 'ultra_strict_required:true' in plain_repr(v) # since "float" schema has ultra-strict behaviour assert repr(v.validate_python((1, 2))) == '(1.0, 2)' assert repr(v.validate_python((1.0, (2.0, 3)))) == '(1.0, (2.0, 3))' @@ -501,3 +493,281 @@ def test_left_to_right_union_strict(): out = v.validate_python(1) assert out == 1.0 assert isinstance(out, float) + + +def test_union_function_before_called_once(): + # See https://github.com/pydantic/pydantic/issues/6830 - in particular the + # smart union validator used to call `remove_prefix` twice, which is not + # ideal from a user perspective. + class SpecialValues(str, Enum): + DEFAULT = 'default' + OTHER = 'other' + + special_values_schema = core_schema.no_info_after_validator_function(SpecialValues, core_schema.str_schema()) + + validator_called_count = 0 + + def remove_prefix(v: str): + nonlocal validator_called_count + validator_called_count += 1 + if v.startswith('uuid::'): + return v[6:] + return v + + prefixed_uuid_schema = core_schema.no_info_before_validator_function(remove_prefix, core_schema.uuid_schema()) + + v = SchemaValidator(core_schema.union_schema([special_values_schema, prefixed_uuid_schema])) + + assert v.validate_python('uuid::12345678-1234-5678-1234-567812345678') == UUID( + '12345678-1234-5678-1234-567812345678' + ) + assert validator_called_count == 1 + + +@pytest.mark.parametrize( + ('schema', 'input_value', 'expected_value'), + ( + ( + core_schema.uuid_schema(), + '12345678-1234-5678-1234-567812345678', + UUID('12345678-1234-5678-1234-567812345678'), + ), + (core_schema.date_schema(), '2020-01-01', date(2020, 1, 1)), + (core_schema.time_schema(), '00:00:00', time(0, 0, 0)), + # In V2.4 these already returned strings, so we keep this behaviour in V2 + (core_schema.datetime_schema(), '2020-01-01:00:00:00', '2020-01-01:00:00:00'), + (core_schema.url_schema(), 'https://foo.com', 'https://foo.com'), + 
(core_schema.multi_host_url_schema(), 'https://bar.com,foo.com', 'https://bar.com,foo.com'), + ), +) +def test_smart_union_json_string_types(schema: core_schema.CoreSchema, input_value: str, expected_value: Any): + # Many types have to be represented in strings as JSON, we make sure that + # when parsing in JSON mode these types are preferred + # TODO: in V3 we will make str win in all these cases. + + validator = SchemaValidator(core_schema.union_schema([schema, core_schema.str_schema()])) + assert validator.validate_json(f'"{input_value}"') == expected_value + # in Python mode the string will be preferred + assert validator.validate_python(input_value) == input_value + + +@pytest.mark.parametrize( + ('schema', 'input_value'), + ( + pytest.param( + core_schema.uuid_schema(), + '12345678-1234-5678-1234-567812345678', + marks=pytest.mark.xfail(reason='TODO: V3'), + ), + (core_schema.date_schema(), '2020-01-01'), + (core_schema.time_schema(), '00:00:00'), + (core_schema.datetime_schema(), '2020-01-01:00:00:00'), + (core_schema.url_schema(), 'https://foo.com'), + (core_schema.multi_host_url_schema(), 'https://bar.com,foo.com'), + ), +) +def test_smart_union_json_string_types_str_first(schema: core_schema.CoreSchema, input_value: str): + # As above, but reversed order; str should always win + validator = SchemaValidator(core_schema.union_schema([core_schema.str_schema(), schema])) + assert validator.validate_json(f'"{input_value}"') == input_value + assert validator.validate_python(input_value) == input_value + + +def test_smart_union_default_fallback(): + """Using a default value does not affect the exactness of the smart union match.""" + + class ModelA: + x: int + y: int = 1 + + class ModelB: + x: int + + schema = core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + { + 'x': core_schema.model_field(core_schema.int_schema()), + 'y': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), 
default=1) + ), + } + ), + ), + core_schema.model_schema( + ModelB, core_schema.model_fields_schema({'x': core_schema.model_field(core_schema.int_schema())}) + ), + ] + ) + + validator = SchemaValidator(schema) + + result = validator.validate_python({'x': 1}) + assert isinstance(result, ModelA) + assert result.x == 1 + assert result.y == 1 + + # passing a ModelB explicitly will not match the default value + b = ModelB() + assert validator.validate_python(b) is b + + +def test_smart_union_model_field(): + class ModelA: + x: int + + class ModelB: + x: str + + schema = core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, core_schema.model_fields_schema({'x': core_schema.model_field(core_schema.int_schema())}) + ), + core_schema.model_schema( + ModelB, core_schema.model_fields_schema({'x': core_schema.model_field(core_schema.str_schema())}) + ), + ] + ) + + validator = SchemaValidator(schema) + + result = validator.validate_python({'x': 1}) + assert isinstance(result, ModelA) + assert result.x == 1 + + result = validator.validate_python({'x': '1'}) + assert isinstance(result, ModelB) + assert result.x == '1' + + +def test_smart_union_dataclass_field(): + @dataclass + class ModelA: + x: int + + @dataclass + class ModelB: + x: str + + schema = core_schema.union_schema( + [ + core_schema.dataclass_schema( + ModelA, + core_schema.dataclass_args_schema( + 'ModelA', [core_schema.dataclass_field('x', core_schema.int_schema())] + ), + ['x'], + ), + core_schema.dataclass_schema( + ModelB, + core_schema.dataclass_args_schema( + 'ModelB', [core_schema.dataclass_field('x', core_schema.str_schema())] + ), + ['x'], + ), + ] + ) + + validator = SchemaValidator(schema) + + result = validator.validate_python({'x': 1}) + assert isinstance(result, ModelA) + assert result.x == 1 + + result = validator.validate_python({'x': '1'}) + assert isinstance(result, ModelB) + assert result.x == '1' + + +def test_smart_union_with_any(): + """any is preferred over lax validations""" + 
+ # str not coerced to int + schema = core_schema.union_schema([core_schema.int_schema(), core_schema.any_schema()]) + validator = SchemaValidator(schema) + assert validator.validate_python('1') == '1' + + # int *is* coerced to float, this is a strict validation + schema = core_schema.union_schema([core_schema.float_schema(), core_schema.any_schema()]) + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '1.0' + + +def test_smart_union_validator_function(): + """adding a validator function should not change smart union behaviour""" + + inner_schema = core_schema.union_schema([core_schema.int_schema(), core_schema.float_schema()]) + + validator = SchemaValidator(inner_schema) + assert repr(validator.validate_python(1)) == '1' + assert repr(validator.validate_python(1.0)) == '1.0' + + schema = core_schema.union_schema( + [core_schema.no_info_after_validator_function(lambda v: v * 2, inner_schema), core_schema.str_schema()] + ) + + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '2' + assert repr(validator.validate_python(1.0)) == '2.0' + assert validator.validate_python('1') == '1' + + schema = core_schema.union_schema( + [ + core_schema.no_info_wrap_validator_function(lambda v, handler: handler(v) * 2, inner_schema), + core_schema.str_schema(), + ] + ) + + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '2' + assert repr(validator.validate_python(1.0)) == '2.0' + assert validator.validate_python('1') == '1' + + +def test_smart_union_validator_function_one_arm(): + """adding a validator function should not change smart union behaviour""" + + schema = core_schema.union_schema( + [ + core_schema.float_schema(), + core_schema.no_info_after_validator_function(lambda v: v * 2, core_schema.int_schema()), + ] + ) + + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '2' + assert repr(validator.validate_python(1.0)) == '1.0' + + schema = 
core_schema.union_schema( + [ + core_schema.float_schema(), + core_schema.no_info_wrap_validator_function(lambda v, handler: handler(v) * 2, core_schema.int_schema()), + ] + ) + + validator = SchemaValidator(schema) + assert repr(validator.validate_python(1)) == '2' + assert repr(validator.validate_python(1.0)) == '1.0' + + +def test_int_not_coerced_to_enum(): + class BinaryEnum(IntEnum): + ZERO = 0 + ONE = 1 + + enum_schema = core_schema.lax_or_strict_schema( + core_schema.no_info_after_validator_function(BinaryEnum, core_schema.int_schema()), + core_schema.is_instance_schema(BinaryEnum), + ) + + schema = core_schema.union_schema([enum_schema, core_schema.int_schema()]) + + validator = SchemaValidator(schema) + + assert validator.validate_python(0) is not BinaryEnum.ZERO + assert validator.validate_python(1) is not BinaryEnum.ONE + assert validator.validate_python(BinaryEnum.ZERO) is BinaryEnum.ZERO + assert validator.validate_python(BinaryEnum.ONE) is BinaryEnum.ONE From d3416f7bf90ce4b9cc84422c63d77fa4a198c052 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 8 Nov 2023 21:56:19 +0000 Subject: [PATCH 111/550] bump version to 2.14.0 (#1066) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 13874f1fd..f27eb2fb2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,7 +321,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.13.0" +version = "2.14.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 6c1d1d948..d2a97ec9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.13.0" +version = "2.14.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 5712edfd495629dc3f9e2a52537415a9c02d9461 Mon Sep 17 00:00:00 2001 From: Luca Blight <46580497+Luca-Blight@users.noreply.github.com> Date: Thu, 9 Nov 2023 04:23:28 -0500 Subject: [PATCH 112/550] Adopt ruff 
formatter instead of black (ready for review) (#1051) --- .github/workflows/ci.yml | 6 +++--- Makefile | 12 ++++++------ pyproject.toml | 11 +++-------- python/pydantic_core/_pydantic_core.pyi | 1 - tests/requirements-linting.txt | 5 ++--- tests/serializers/test_functions.py | 2 +- tests/test_errors.py | 6 +++--- tests/validators/test_arguments.py | 8 ++++---- tests/validators/test_bool.py | 2 +- tests/validators/test_int.py | 6 +++--- tests/validators/test_literal.py | 2 +- tests/validators/test_string.py | 4 ++-- tests/validators/test_typed_dict.py | 6 +++--- 13 files changed, 32 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fa9dfa0e0..de0e52e7a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -315,7 +315,7 @@ jobs: version: '3.1.32' actions-cache-folder: emsdk-cache - - run: pip install 'maturin>=1,<2' 'black>=22.3.0,<23' typing_extensions + - run: pip install 'maturin>=1,<2' 'ruff==0.1.3' typing_extensions - name: build wheels run: make build-wasm @@ -442,7 +442,7 @@ jobs: python-version: '3.11' architecture: ${{ matrix.python-architecture || 'x64' }} - - run: pip install -U twine 'black>=22.3.0,<23' typing_extensions + - run: pip install -U twine 'ruff==0.1.3' typing_extensions # generate self-schema now, so we don't have to do so inside docker in maturin build - run: python generate_self_schema.py @@ -501,7 +501,7 @@ jobs: with: components: llvm-tools - - run: pip install -U 'black>=22.3.0,<23' typing_extensions + - run: pip install -U 'ruff==0.1.3' typing_extensions # generate self-schema now, so we don't have to do so inside docker in maturin build - run: python generate_self_schema.py diff --git a/Makefile b/Makefile index 4e97e6f42..d9e0d0e0a 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .DEFAULT_GOAL := all -black = black python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py -ruff = ruff python/pydantic_core tests generate_self_schema.py 
wasm-preview/run_tests.py +sources = python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py + mypy-stubtest = python -m mypy.stubtest pydantic_core._pydantic_core --allowlist .mypy-stubtest-allowlist # using pip install cargo (via maturin via pip) doesn't get the tty handle @@ -90,14 +90,14 @@ build-wasm: .PHONY: format format: - $(black) - $(ruff) --fix --exit-zero + ruff --fix $(sources) + ruff format $(sources) cargo fmt .PHONY: lint-python lint-python: - $(ruff) - $(black) --check --diff + ruff $(sources) + ruff format --check $(sources) $(mypy-stubtest) griffe dump -f -d google -LWARNING -o/dev/null python/pydantic_core diff --git a/pyproject.toml b/pyproject.toml index 7afacd57a..a40f160f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,13 +57,15 @@ features = ["pyo3/extension-module"] line-length = 120 extend-select = ['Q', 'RUF100', 'C90', 'I'] extend-ignore = [ - 'E501', # ignore line too long and let black handle it 'E721', # using type() instead of isinstance() - we use this in tests ] flake8-quotes = {inline-quotes = 'single', multiline-quotes = 'double'} mccabe = { max-complexity = 13 } isort = { known-first-party = ['pydantic_core', 'tests'] } +[tool.ruff.format] +quote-style = 'single' + [tool.pytest.ini_options] testpaths = 'tests' log_format = '%(name)s %(levelname)s: %(message)s' @@ -97,13 +99,6 @@ exclude_lines = [ '@overload', ] -[tool.black] -color = true -line-length = 120 -target-version = ['py37', 'py38', 'py39', 'py310'] -skip-string-normalization = true -skip-magic-trailing-comma = true - # configuring https://github.com/pydantic/hooky [tool.hooky] assignees = ['samuelcolvin', 'adriangb', 'dmontagu', 'davidhewitt', 'lig'] diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index f28b7a12a..b452d2f17 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -847,7 +847,6 @@ def list_all_errors() -> list[ErrorTypeInfo]: 
Returns: A list of `ErrorTypeInfo` typed dicts. """ - @final class TzInfo(datetime.tzinfo): def tzname(self, _dt: datetime.datetime | None) -> str | None: ... diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 655ff8bda..332398740 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,5 +1,4 @@ -black==23.10.1 -griffe==0.36.9 -pyright==1.1.334 +griffe==0.36.7 +pyright==1.1.332 ruff==0.1.3 mypy==1.6.1 diff --git a/tests/serializers/test_functions.py b/tests/serializers/test_functions.py index 318254602..8851a7d36 100644 --- a/tests/serializers/test_functions.py +++ b/tests/serializers/test_functions.py @@ -228,7 +228,7 @@ def append_args(value, info): 'exclude_defaults=False, exclude_none=False, round_trip=False)' ) assert s.to_json(123) == ( - b'"123 info=SerializationInfo(include=None, exclude=None, mode=\'json\', by_alias=True, exclude_unset=False, ' + b"\"123 info=SerializationInfo(include=None, exclude=None, mode='json', by_alias=True, exclude_unset=False, " b'exclude_defaults=False, exclude_none=False, round_trip=False)"' ) diff --git a/tests/test_errors.py b/tests/test_errors.py index 0d7a966e1..05815aec5 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1050,9 +1050,9 @@ def test_loc_with_dots(pydantic_version): ] # insert_assert(str(exc_info.value)) assert str(exc_info.value) == ( - "1 validation error for typed-dict\n" - "`foo.bar`.0\n" - " Input should be a valid integer, unable to parse string as an integer " + '1 validation error for typed-dict\n' + '`foo.bar`.0\n' + ' Input should be a valid integer, unable to parse string as an integer ' "[type=int_parsing, input_value='x', input_type=str]\n" f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing' ) diff --git a/tests/validators/test_arguments.py b/tests/validators/test_arguments.py index 4ef581b47..4cf1d3ad2 100644 --- a/tests/validators/test_arguments.py +++ 
b/tests/validators/test_arguments.py @@ -1009,11 +1009,11 @@ def test_error_display(pydantic_version): ] # insert_assert(str(exc_info.value)) assert str(exc_info.value) == ( - "1 validation error for arguments\n" - "b\n" - " Missing required argument [type=missing_argument, " + '1 validation error for arguments\n' + 'b\n' + ' Missing required argument [type=missing_argument, ' "input_value=ArgsKwargs((), {'a': 1}), input_type=ArgsKwargs]\n" - f" For further information visit https://errors.pydantic.dev/{pydantic_version}/v/missing_argument" + f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/missing_argument' ) # insert_assert(exc_info.value.json(include_url=False)) assert exc_info.value.json(include_url=False) == ( diff --git a/tests/validators/test_bool.py b/tests/validators/test_bool.py index e71d41cb1..3ac900701 100644 --- a/tests/validators/test_bool.py +++ b/tests/validators/test_bool.py @@ -63,7 +63,7 @@ def test_bool_error(pydantic_version): '1 validation error for bool\n' ' Input should be a valid boolean, ' "unable to interpret input [type=bool_parsing, input_value='wrong', input_type=str]\n" - f" For further information visit https://errors.pydantic.dev/{pydantic_version}/v/bool_parsing" + f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/bool_parsing' ) assert exc_info.value.errors(include_url=False) == [ { diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 61acab7fb..80dd1cf73 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -357,10 +357,10 @@ def test_too_long(pydantic_version): ] # insert_assert(repr(exc_info.value)) assert repr(exc_info.value) == ( - "1 validation error for int\n" - " Unable to parse input string as an integer, exceeded maximum size " + '1 validation error for int\n' + ' Unable to parse input string as an integer, exceeded maximum size ' "[type=int_parsing_size, 
input_value='111111111111111111111111...11111111111111111111111', input_type=str]\n" - f" For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing_size" + f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/int_parsing_size' ) diff --git a/tests/validators/test_literal.py b/tests/validators/test_literal.py index e1397aeea..d294f866c 100644 --- a/tests/validators/test_literal.py +++ b/tests/validators/test_literal.py @@ -78,7 +78,7 @@ pytest.param( ['a', 'b'], 'c', - Err("Input should be 'a' or 'b' [type=literal_error, input_value=\'c\', input_type=str]"), + Err("Input should be 'a' or 'b' [type=literal_error, input_value='c', input_type=str]"), id='wrong-multiple-str', ), ([1, '1'], 1, 1), diff --git a/tests/validators/test_string.py b/tests/validators/test_string.py index bc2102de2..22bcd5445 100644 --- a/tests/validators/test_string.py +++ b/tests/validators/test_string.py @@ -346,9 +346,9 @@ def test_backtracking_regex_rust_unsupported(mode) -> None: SchemaValidator(core_schema.str_schema(pattern=pattern), core_schema.CoreConfig(regex_engine='rust-regex')) assert exc_info.value.args[0] == ( - 'Error building \"str\" validator:\n' + 'Error building "str" validator:\n' ' SchemaError: regex parse error:\n' - ' r(#*)\".*?\"\\1\n' + ' r(#*)".*?"\\1\n' ' ^^\n' 'error: backreferences are not supported' ) diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 5f0729d25..8fb25cff6 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -107,10 +107,10 @@ def test_missing_error(pydantic_version): v.validate_python({'field_a': b'abc'}) # insert_assert(str(exc_info.value)) assert str(exc_info.value) == ( - "1 validation error for typed-dict\n" - "field_b\n" + '1 validation error for typed-dict\n' + 'field_b\n' " Field required [type=missing, input_value={'field_a': b'abc'}, input_type=dict]\n" - f" For further information visit 
https://errors.pydantic.dev/{pydantic_version}/v/missing" + f' For further information visit https://errors.pydantic.dev/{pydantic_version}/v/missing' ) From b27455eb90fcc18d4c728c87a8da39bd133e3e97 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 9 Nov 2023 09:32:24 +0000 Subject: [PATCH 113/550] Bump ruff from 0.1.3 to 0.1.5 (#1067) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: David Hewitt --- tests/requirements-linting.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 332398740..38eacfdbc 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ -griffe==0.36.7 -pyright==1.1.332 -ruff==0.1.3 +griffe==0.36.9 +pyright==1.1.334 +ruff==0.1.5 mypy==1.6.1 From 4c2ea59573bfc2dba2bb260e49e8b5c003bfcc95 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 9 Nov 2023 12:20:27 +0000 Subject: [PATCH 114/550] Restore manylinux-compatible PGO builds (#1068) --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index de0e52e7a..5a2b01049 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -511,6 +511,7 @@ jobs: - name: build initial wheel uses: PyO3/maturin-action@v1 with: + manylinux: auto args: > --release --out pgo-wheel @@ -538,6 +539,7 @@ jobs: - name: build pgo-optimized wheel uses: PyO3/maturin-action@v1 with: + manylinux: auto args: > --release --out dist From fbae08ce88deb6b78a688bd4948a812389bf1d27 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 9 Nov 2023 12:33:49 +0000 Subject: [PATCH 115/550] Bump version to 2.14.1 (#1069) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f27eb2fb2..e00130f7d 100644 
--- a/Cargo.lock +++ b/Cargo.lock @@ -321,7 +321,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.14.0" +version = "2.14.1" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index d2a97ec9b..f7c859534 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.14.0" +version = "2.14.1" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From baed94340599a03332ac856b5d32e0b7087309b4 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 9 Nov 2023 15:49:55 +0000 Subject: [PATCH 116/550] Fix invalid link in docstring (#1070) --- python/pydantic_core/_pydantic_core.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index b452d2f17..4c99a4d61 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -389,7 +389,7 @@ def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> A """ Deserialize JSON data to a Python object. - This is effectively a faster version of [`json.loads()`][json.loads]. + This is effectively a faster version of `json.loads()`. Arguments: data: The JSON data to deserialize. 
From a35b1829cd999628587f19b152bf4c5ab4a8101b Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 14 Nov 2023 14:12:20 +0000 Subject: [PATCH 117/550] restore pypy builds for x86_64 (#1072) --- .github/workflows/ci.yml | 117 ++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 56 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5a2b01049..bb0918f8a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -365,74 +365,73 @@ jobs: path: dist build: - name: build on ${{ matrix.platform || matrix.os }} (${{ matrix.target }} - ${{ matrix.manylinux || 'auto' }}) + name: build on ${{ matrix.os }} (${{ matrix.target }} - ${{ matrix.interpreter || 'all' }}${{ matrix.os == 'linux' && format(' - {0}', matrix.manylinux == 'auto' && 'manylinux' || matrix.manylinux) || '' }}) # only run on push to main and on release if: startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || contains(github.event.pull_request.labels.*.name, 'Full Build') strategy: fail-fast: false matrix: - os: [ubuntu, macos, windows] + os: [linux, macos, windows] target: [x86_64, aarch64] manylinux: [auto] include: - - os: ubuntu - platform: linux - - os: windows - ls: dir - interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 pypy3.8 pypy3.9 pypy3.10 - - os: windows - ls: dir - target: i686 - python-architecture: x86 - interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 - - os: windows - ls: dir - target: aarch64 - interpreter: 3.11 3.12 - - os: macos - target: aarch64 - interpreter: 3.7 3.8 3.9 pypy3.8 pypy3.9 pypy3.10 - - os: ubuntu - platform: linux + # manylinux for various platforms, plus x86_64 pypy + - os: linux + manylinux: auto target: i686 - - os: ubuntu - platform: linux + - os: linux + manylinux: auto target: aarch64 - - - os: ubuntu - platform: linux + - os: linux + manylinux: auto target: armv7 interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 - # musllinux - - os: ubuntu - platform: linux - target: x86_64 - manylinux: musllinux_1_1 - - 
os: ubuntu - platform: linux - target: aarch64 - manylinux: musllinux_1_1 - - os: ubuntu - platform: linux + - os: linux + manylinux: auto target: ppc64le interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 - - os: ubuntu - platform: linux + - os: linux + manylinux: auto target: s390x interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 - exclude: - # Optimized PGO builds for x86_64 manylinux and windows follow a different matrix, - # maybe in future maturin-action can support this automatically - - os: ubuntu - target: x86_64 + - os: linux manylinux: auto + target: x86_64 + interpreter: pypy3.7 pypy3.8 pypy3.9 pypy3.10 + + # musllinux + - os: linux + manylinux: musllinux_1_1 + target: x86_64 + - os: linux + manylinux: musllinux_1_1 + target: aarch64 + + # macos; + # all versions x86_64 + # arm pypy and older pythons which can't be run on the arm hardware for PGO + - os: macos + target: x86_64 + - os: macos + target: aarch64 + interpreter: 3.7 3.8 3.9 pypy3.8 pypy3.9 pypy3.10 + + # windows; + # x86_64 pypy builds are not PGO optimized + # i686 not supported by pypy + # aarch64 only 3.11 and up, also not PGO optimized - os: windows target: x86_64 - # Windows on arm64 only supports Python 3.11+ + interpreter: pypy3.8 pypy3.9 pypy3.10 + - os: windows + target: i686 + python-architecture: x86 + interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 - os: windows target: aarch64 + interpreter: 3.11 3.12 - runs-on: ${{ matrix.os }}-latest + runs-on: ${{ (matrix.os == 'linux' && 'ubuntu') || matrix.os }}-latest steps: - uses: actions/checkout@v4 @@ -451,13 +450,12 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - manylinux: ${{ matrix.manylinux || 'auto' }} - container: ${{ matrix.container }} - args: --release --out dist --interpreter ${{ matrix.interpreter || '3.7 3.8 3.9 3.10 3.11 3.12 pypy3.7 pypy3.8 pypy3.9 pypy3.10' }} ${{ matrix.extra-build-args }} + manylinux: ${{ matrix.manylinux == 'manylinux' && 'auto' || matrix.manylinux }} + args: --release --out dist --interpreter ${{ 
matrix.interpreter || '3.7 3.8 3.9 3.10 3.11 3.12 pypy3.7 pypy3.8 pypy3.9 pypy3.10' }} rust-toolchain: stable docker-options: -e CI - - run: ${{ matrix.ls || 'ls -lh' }} dist/ + - run: ${{ (matrix.os == 'windows' && 'dir') || 'ls -lh' }} dist/ - run: twine check --strict dist/* @@ -473,20 +471,27 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-latest-xlarge] + os: [linux, windows, macos] interpreter: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] include: - - os: windows-latest + # standard runners with override for macos arm + - os: linux + runs-on: ubuntu-latest + - os: windows ls: dir + runs-on: windows-latest + - os: macos + runs-on: macos-latest-xlarge exclude: - - os: macos-latest-xlarge + # macos arm only supported from 3.10 and up + - os: macos interpreter: '3.7' - - os: macos-latest-xlarge + - os: macos interpreter: '3.8' - - os: macos-latest-xlarge + - os: macos interpreter: '3.9' - runs-on: ${{ matrix.os }} + runs-on: ${{ matrix.runs-on }} steps: - uses: actions/checkout@v4 From 439eeacd70b9f87ab36da35def6b9e9854b03672 Mon Sep 17 00:00:00 2001 From: Arthur Pastel Date: Tue, 14 Nov 2023 15:17:19 +0100 Subject: [PATCH 118/550] Bump python version for benchmarks (#1064) --- .github/workflows/codspeed.yml | 8 +++++--- Cargo.toml | 5 +++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index 32bc83a1e..b01d4ff7e 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -18,7 +18,7 @@ jobs: - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.12' - uses: actions/cache@v3 id: cache-py @@ -48,10 +48,12 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 + with: + key: v1 - name: Compile pydantic-core for profiling run: | - pip install -e . --config-settings=build-args='--verbose' -v + pip install -e . 
--config-settings=build-args='--verbose --profile codspeed' -v env: CONST_RANDOM_SEED: 0 # Fix the compile time RNG seed RUSTFLAGS: "-Cprofile-generate=${{ github.workspace }}/profdata" @@ -64,7 +66,7 @@ jobs: - name: Compile pydantic-core for benchmarking run: | - pip install -e . --config-settings=build-args='--verbose' -v + pip install -e . --config-settings=build-args='--verbose --profile codspeed' -v env: CONST_RANDOM_SEED: 0 # Fix the compile time RNG seed RUSTFLAGS: "-Cprofile-use=${{ github.workspace }}/merged.profdata" diff --git a/Cargo.toml b/Cargo.toml index f7c859534..120b8496d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,11 @@ strip = true debug = true strip = false +[profile.codspeed] +inherits = "release" +debug = true +strip = false + [dev-dependencies] pyo3 = { version = "0.20.0", features = ["auto-initialize"] } From b18f4f95a5c852fb7e97a81a11e67bae31c58e0a Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 14 Nov 2023 14:26:12 +0000 Subject: [PATCH 119/550] bump version to 2.14.2 (#1073) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e00130f7d..244ec1cda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,7 +321,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.14.1" +version = "2.14.2" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 120b8496d..e938382e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.14.1" +version = "2.14.2" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 73b118444cde9b14578f87fc866cf4271c8ab512 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 14 Nov 2023 17:38:48 +0000 Subject: [PATCH 120/550] uprev to 2.24.3 --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 244ec1cda..40ca98e4d 100644 --- 
a/Cargo.lock +++ b/Cargo.lock @@ -321,7 +321,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.14.2" +version = "2.14.3" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index e938382e4..39da916f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.14.2" +version = "2.14.3" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 3fea8332ed300c5439acde2f92e8c66f2ff5b42b Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:46:01 -0600 Subject: [PATCH 121/550] Fix bug re `custom_init` on members of `Union` (#1076) --- src/validators/model.rs | 5 +++- tests/validators/test_model_init.py | 45 +++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/validators/model.rs b/src/validators/model.rs index 0299ce5d8..a571ccd9e 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -257,7 +257,10 @@ impl ModelValidator { // this work with from_attributes, and would essentially allow you to // handle init vars by adding them to the __init__ signature. 
if let Some(kwargs) = input.as_kwargs(py) { - return Ok(self.class.call(py, (), Some(kwargs))?); + return self + .class + .call(py, (), Some(kwargs)) + .map_err(|e| convert_err(py, e, input)); } } diff --git a/tests/validators/test_model_init.py b/tests/validators/test_model_init.py index 5521f8da4..b0c28dc86 100644 --- a/tests/validators/test_model_init.py +++ b/tests/validators/test_model_init.py @@ -479,3 +479,48 @@ def _wrap_validator(cls, v, validator, info): gc.collect() assert ref() is None + + +def test_model_custom_init_with_union() -> None: + class A: + def __init__(self, **kwargs): + assert 'a' in kwargs + self.a = kwargs.get('a') + + class B: + def __init__(self, **kwargs): + assert 'b' in kwargs + self.b = kwargs.get('b') + + schema = { + 'type': 'union', + 'choices': [ + { + 'type': 'model', + 'cls': A, + 'schema': { + 'type': 'model-fields', + 'fields': {'a': {'type': 'model-field', 'schema': {'type': 'bool'}}}, + 'model_name': 'A', + }, + 'custom_init': True, + 'ref': '__main__.A:4947206928', + }, + { + 'type': 'model', + 'cls': B, + 'schema': { + 'type': 'model-fields', + 'fields': {'b': {'type': 'model-field', 'schema': {'type': 'bool'}}}, + 'model_name': 'B', + }, + 'custom_init': True, + 'ref': '__main__.B:4679932848', + }, + ], + } + + validator = SchemaValidator(schema) + + assert validator.validate_python({'a': False}).a is False + assert validator.validate_python({'b': True}).b is True From 203b395fe0d09d8c30c37462c46192b96e6dac00 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 16 Nov 2023 10:14:31 -0600 Subject: [PATCH 122/550] Fix validation of `Literal` from JSON keys when used as `dict` key (#1075) Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com> --- src/validators/literal.rs | 16 +++++++++++++--- tests/test.rs | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/validators/literal.rs 
b/src/validators/literal.rs index c9a846695..686920cca 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -9,7 +9,7 @@ use pyo3::{intern, PyTraverseError, PyVisit}; use crate::build_tools::{py_schema_err, py_schema_error_type}; use crate::errors::{ErrorType, ValError, ValResult}; -use crate::input::Input; +use crate::input::{Input, ValidationMatch}; use crate::py_gc::PyGcTraverse; use crate::tools::SchemaDict; @@ -116,8 +116,18 @@ impl LiteralLookup { } } if let Some(expected_strings) = &self.expected_str { - // dbg!(expected_strings); - if let Ok(either_str) = input.exact_str() { + let validation_result = if input.is_python() { + input.exact_str() + } else { + // Strings coming from JSON are treated as "strict" but not "exact" for reasons + // of parsing types like UUID; see the implementation of `validate_str` for Json + // inputs for justification. We might change that eventually, but for now we need + // to work around this when loading from JSON + // V3 TODO: revisit making this "exact" for JSON inputs + input.validate_str(true, false).map(ValidationMatch::into_inner) + }; + + if let Ok(either_str) = validation_result { let cow = either_str.as_cow()?; if let Some(id) = expected_strings.get(cow.as_ref()) { return Ok(Some((input, &self.values[*id]))); diff --git a/tests/test.rs b/tests/test.rs index 348520435..4b0918d40 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use _pydantic_core::SchemaSerializer; + use _pydantic_core::{SchemaSerializer, SchemaValidator}; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -86,4 +86,35 @@ a = A() assert_eq!(serialized, b"{\"b\":\"b\"}"); }); } + + #[test] + fn test_literal_schema() { + Python::with_gil(|py| { + let code = r#" +schema = { + "type": "dict", + "keys_schema": { + "type": "literal", + "expected": ["a", "b"], + }, + "values_schema": { + "type": "str", + }, + "strict": False, +} +json_input = '{"a": "something"}' + "#; + let locals = PyDict::new(py); 
+ py.run(code, None, Some(locals)).unwrap(); + let schema: &PyDict = locals.get_item("schema").unwrap().unwrap().extract().unwrap(); + let json_input: &PyAny = locals.get_item("json_input").unwrap().unwrap().extract().unwrap(); + let binding = SchemaValidator::py_new(py, schema, None) + .unwrap() + .validate_json(py, json_input, None, None, None) + .unwrap(); + let validation_result: &PyAny = binding.extract(py).unwrap(); + let repr = format!("{}", validation_result.repr().unwrap()); + assert_eq!(repr, "{'a': 'something'}"); + }); + } } From 2f0c8c27fe8475fe55ca4326aee8ba462ad6fb4e Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Fri, 17 Nov 2023 09:38:38 +0000 Subject: [PATCH 123/550] Update ci for Rust 1.74 (#1079) --- Cargo.toml | 4 ++++ Makefile | 12 +++++------- tests/test.rs | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 39da916f5..861473953 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,3 +75,7 @@ pyo3 = { version = "0.20.0", features = ["auto-initialize"] } version_check = "0.9.4" # used where logic has to be version/distribution specific, e.g. 
pypy pyo3-build-config = { version = "0.20.0" } + +[lints.clippy] +dbg_macro = "warn" +print_stdout = "warn" diff --git a/Makefile b/Makefile index d9e0d0e0a..e477cf767 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ .DEFAULT_GOAL := all -sources = python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py +sources = python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py mypy-stubtest = python -m mypy.stubtest pydantic_core._pydantic_core --allowlist .mypy-stubtest-allowlist @@ -90,14 +90,14 @@ build-wasm: .PHONY: format format: - ruff --fix $(sources) - ruff format $(sources) + ruff --fix $(sources) + ruff format $(sources) cargo fmt .PHONY: lint-python lint-python: - ruff $(sources) - ruff format --check $(sources) + ruff $(sources) + ruff format --check $(sources) $(mypy-stubtest) griffe dump -f -d google -LWARNING -o/dev/null python/pydantic_core @@ -109,8 +109,6 @@ lint-rust: cargo clippy --tests -- \ -D warnings \ -W clippy::pedantic \ - -W clippy::dbg_macro \ - -W clippy::print_stdout \ -A clippy::cast-possible-truncation \ -A clippy::cast-possible-wrap \ -A clippy::cast-precision-loss \ diff --git a/tests/test.rs b/tests/test.rs index 4b0918d40..e0b76a4d8 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -22,7 +22,7 @@ mod tests { // 'type': 'function-wrap', // 'function': lambda: None, // }, - let code = r#"{ + let code = r"{ 'type': 'definitions', 'schema': {'type': 'definition-ref', 'schema_ref': 'C-ref'}, 'definitions': [ @@ -44,7 +44,7 @@ mod tests { }, }, ] - }"#; + }"; let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap(); SchemaSerializer::py_new(py, schema, None).unwrap(); }); From ca83d24736816ea9e862b7804b32f63e1312bb06 Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:49:11 +0100 Subject: [PATCH 124/550] Fix validation of negative floats when using `multiple_of` (#1077) Co-authored-by: David Hewitt --- 
src/validators/float.rs | 2 +- tests/validators/test_decimal.py | 1 + tests/validators/test_float.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/validators/float.rs b/src/validators/float.rs index b72ffafc0..126e539c2 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -109,7 +109,7 @@ impl Validator for ConstrainedFloatValidator { } if let Some(multiple_of) = self.multiple_of { let rem = float % multiple_of; - let threshold = float / 1e9; + let threshold = float.abs() / 1e9; if rem.abs() > threshold && (rem - multiple_of).abs() > threshold { return Err(ValError::new( ErrorType::MultipleOf { diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index b9fabeaed..69cd52738 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -188,6 +188,7 @@ def test_decimal_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_va (0.1, 1, None), (0.1, 1.0, None), (0.1, int(5e10), None), + (2.0, -2.0, None), ], ids=repr, ) diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index 4e3bda0c4..56c03d40e 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -121,6 +121,7 @@ def test_float_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_valu (0.1, 1, None), (0.1, 1.0, None), (0.1, int(5e10), None), + (2.0, -2.0, None), ], ids=repr, ) From 223aa278499481c6830e3c4a3dd6a1d75a9e15f2 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Fri, 17 Nov 2023 09:19:05 -0600 Subject: [PATCH 125/550] uprev: 2.14.4 (#1081) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 40ca98e4d..249fdfbf9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,7 +321,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.14.3" +version = "2.14.4" dependencies = [ "ahash", "base64", diff --git 
a/Cargo.toml b/Cargo.toml index 861473953..87a2f2282 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.14.3" +version = "2.14.4" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 14c2971729b0f8354ca7da8894b7371146ce30cd Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Fri, 17 Nov 2023 19:56:55 +0000 Subject: [PATCH 126/550] Fix bug with UUID validation from json with a wrap validator (#1080) --- src/input/input_abstract.rs | 2 +- src/validators/dataclass.rs | 3 ++- src/validators/uuid.rs | 8 +++++--- tests/validators/test_dataclasses.py | 21 +++++++++++++++++++++ tests/validators/test_uuid.py | 15 ++++++++++++++- 5 files changed, 43 insertions(+), 6 deletions(-) diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index ba6fbd0a1..765ede739 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -14,7 +14,7 @@ use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta}; use super::return_enums::{EitherBytes, EitherInt, EitherString}; use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, ValidationMatch}; -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum InputType { Python, Json, diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 1646f5ea6..4140b1b26 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -8,6 +8,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::input::InputType; use crate::input::{BorrowInput, GenericArguments, Input, ValidationMatch}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; @@ -535,7 +536,7 @@ impl Validator for DataclassValidator { } else { 
Ok(input.to_object(py)) } - } else if state.strict_or(self.strict) && input.is_python() { + } else if state.strict_or(self.strict) && state.extra().input_type == InputType::Python { Err(ValError::new( ErrorType::DataclassExactType { class_name: self.get_name().to_string(), diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index 9e4ce9fb5..3324d5ea1 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -9,6 +9,7 @@ use uuid::Uuid; use crate::build_tools::is_strict; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; +use crate::input::InputType; use crate::tools::SchemaDict; use super::model::create_class; @@ -108,7 +109,7 @@ impl Validator for UuidValidator { } } Ok(py_input.to_object(py)) - } else if state.strict_or(self.strict) && input.is_python() { + } else if state.strict_or(self.strict) && state.extra().input_type == InputType::Python { Err(ValError::new( ErrorType::IsInstanceOf { class: class.name().unwrap_or("UUID").to_string(), @@ -118,8 +119,9 @@ impl Validator for UuidValidator { )) } else { // In python mode this is a coercion, in JSON mode we treat a UUID string as an - // exact match - if input.is_python() { + // exact match. 
+ // TODO V3: we might want to remove the JSON special case + if state.extra().input_type == InputType::Python { state.floor_exactness(Exactness::Lax); } let uuid = self.get_uuid(input)?; diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 9c2b93dfe..94bf7fd65 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1510,6 +1510,27 @@ def test_dataclass_json(): ] +def test_dataclass_wrap_json(): + # https://github.com/pydantic/pydantic/issues/8147 + schema = core_schema.no_info_wrap_validator_function( + lambda v, handler: handler(v), + core_schema.dataclass_schema( + FooDataclass, + core_schema.dataclass_args_schema( + 'FooDataclass', + [ + core_schema.dataclass_field(name='a', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), + ], + ), + ['a', 'b'], + ), + ) + v = SchemaValidator(schema) + assert v.validate_json('{"a": "hello", "b": true}') == FooDataclass(a='hello', b=True) + assert v.validate_json('{"a": "hello", "b": true}', strict=True) == FooDataclass(a='hello', b=True) + + @pytest.mark.xfail( condition=platform.python_implementation() == 'PyPy', reason='https://foss.heptapod.net/pypy/pypy/-/issues/3899' ) diff --git a/tests/validators/test_uuid.py b/tests/validators/test_uuid.py index 9afe2d2b8..9b4ef60cf 100644 --- a/tests/validators/test_uuid.py +++ b/tests/validators/test_uuid.py @@ -4,7 +4,7 @@ import pytest -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson @@ -197,3 +197,16 @@ def test_uuid_copy(): assert repr(output) == "UUID('a6cc5730-2261-11ee-9c43-2eb5a363657c')" assert c == output assert isinstance(output, UUID) + + +def test_uuid_wrap_json(): + # https://github.com/pydantic/pydantic/issues/8147 + schema = core_schema.no_info_wrap_validator_function(lambda v, handler: handler(v), 
core_schema.uuid_schema()) + v = SchemaValidator(schema) + + assert v.validate_python(UUID('a6cc5730-2261-11ee-9c43-2eb5a363657c'), strict=True) == UUID( + 'a6cc5730-2261-11ee-9c43-2eb5a363657c' + ) + assert v.validate_json('"a6cc5730-2261-11ee-9c43-2eb5a363657c"', strict=True) == UUID( + 'a6cc5730-2261-11ee-9c43-2eb5a363657c' + ) From 2bb53d869b034ee6d6496bdd104e8377bc71c722 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 20 Nov 2023 15:52:43 +0000 Subject: [PATCH 127/550] move all clippy lints into `[lints.clippy]` table (#1083) Co-authored-by: David Hewitt --- Cargo.toml | 27 +++++++++++++++++++++++++++ Makefile | 27 +-------------------------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 87a2f2282..512d98908 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,3 +79,30 @@ pyo3-build-config = { version = "0.20.0" } [lints.clippy] dbg_macro = "warn" print_stdout = "warn" + +# in general we lint against the pedantic group, but we will whitelist +# certain lints which we don't want to enforce (for now) +pedantic = { level = "warn", priority = -1 } +cast_possible_truncation = "allow" +cast_possible_wrap = "allow" +cast_precision_loss = "allow" +cast_sign_loss = "allow" +doc_markdown = "allow" +float_cmp = "allow" +fn_params_excessive_bools = "allow" +if_not_else = "allow" +manual_let_else = "allow" +match_bool = "allow" +match_same_arms = "allow" +missing_errors_doc = "allow" +missing_panics_doc = "allow" +module_name_repetitions = "allow" +must_use_candidate = "allow" +needless_pass_by_value = "allow" +similar_names = "allow" +single_match_else = "allow" +struct_excessive_bools = "allow" +too_many_lines = "allow" +unnecessary_wraps = "allow" +unused_self = "allow" +used_underscore_binding = "allow" diff --git a/Makefile b/Makefile index e477cf767..e1f0bc607 100644 --- a/Makefile +++ b/Makefile @@ -106,32 +106,7 @@ lint-rust: cargo fmt --version cargo fmt --all -- --check cargo clippy --version - cargo clippy 
--tests -- \ - -D warnings \ - -W clippy::pedantic \ - -A clippy::cast-possible-truncation \ - -A clippy::cast-possible-wrap \ - -A clippy::cast-precision-loss \ - -A clippy::cast-sign-loss \ - -A clippy::doc-markdown \ - -A clippy::float-cmp \ - -A clippy::fn-params-excessive-bools \ - -A clippy::if-not-else \ - -A clippy::manual-let-else \ - -A clippy::match-bool \ - -A clippy::match-same-arms \ - -A clippy::missing-errors-doc \ - -A clippy::missing-panics-doc \ - -A clippy::module-name-repetitions \ - -A clippy::must-use-candidate \ - -A clippy::needless-pass-by-value \ - -A clippy::similar-names \ - -A clippy::single-match-else \ - -A clippy::struct-excessive-bools \ - -A clippy::too-many-lines \ - -A clippy::unnecessary-wraps \ - -A clippy::unused-self \ - -A clippy::used-underscore-binding + cargo clippy --tests -- -D warnings .PHONY: lint lint: lint-python lint-rust From be9c21c9fc64b5341c9e1fb4674239b8db9709db Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 21 Nov 2023 07:17:23 +0000 Subject: [PATCH 128/550] Remove lifetime from errors (#1084) --- src/errors/line_error.rs | 83 ++++++++++++------------------ src/errors/mod.rs | 2 +- src/errors/validation_exception.rs | 20 +++---- src/errors/value_exception.rs | 7 +-- src/input/datetime.rs | 10 ++-- src/input/input_abstract.rs | 22 ++++---- src/input/input_json.rs | 46 +++++++++++------ src/input/input_python.rs | 22 +++++--- src/input/input_string.rs | 20 +++---- src/input/return_enums.rs | 52 +++++++++---------- src/input/shared.rs | 14 ++--- src/lookup_key.rs | 19 +++---- src/validators/any.rs | 2 +- src/validators/arguments.rs | 2 +- src/validators/bool.rs | 2 +- src/validators/bytes.rs | 4 +- src/validators/call.rs | 2 +- src/validators/callable.rs | 2 +- src/validators/chain.rs | 2 +- src/validators/custom_error.rs | 5 +- src/validators/dataclass.rs | 21 ++++---- src/validators/date.rs | 4 +- src/validators/datetime.rs | 4 +- src/validators/decimal.rs | 18 ++----- src/validators/definitions.rs | 
4 +- src/validators/dict.rs | 21 ++++---- src/validators/float.rs | 4 +- src/validators/frozenset.rs | 2 +- src/validators/function.rs | 24 ++++----- src/validators/generator.rs | 2 +- src/validators/int.rs | 4 +- src/validators/is_instance.rs | 2 +- src/validators/is_subclass.rs | 2 +- src/validators/json.rs | 4 +- src/validators/json_or_python.rs | 2 +- src/validators/lax_or_strict.rs | 2 +- src/validators/list.rs | 2 +- src/validators/literal.rs | 4 +- src/validators/mod.rs | 12 ++--- src/validators/model.rs | 10 ++-- src/validators/model_fields.rs | 22 ++++---- src/validators/none.rs | 2 +- src/validators/nullable.rs | 2 +- src/validators/set.rs | 2 +- src/validators/string.rs | 4 +- src/validators/time.rs | 2 +- src/validators/timedelta.rs | 2 +- src/validators/tuple.rs | 8 +-- src/validators/typed_dict.rs | 17 +++--- src/validators/union.rs | 28 +++++----- src/validators/url.rs | 24 ++++----- src/validators/uuid.rs | 6 +-- src/validators/with_default.rs | 8 +-- 53 files changed, 294 insertions(+), 319 deletions(-) diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index 3ee4c7894..f5fa11b70 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -9,44 +9,54 @@ use crate::input::Input; use super::location::{LocItem, Location}; use super::types::ErrorType; -pub type ValResult<'a, T> = Result>; +pub type ValResult = Result; + +pub trait AsErrorValue { + fn as_error_value(&self) -> InputValue; +} + +impl<'a, T: Input<'a>> AsErrorValue for T { + fn as_error_value(&self) -> InputValue { + Input::as_error_value(self) + } +} #[cfg_attr(debug_assertions, derive(Debug))] -pub enum ValError<'a> { - LineErrors(Vec>), +pub enum ValError { + LineErrors(Vec), InternalErr(PyErr), Omit, UseDefault, } -impl<'a> From for ValError<'a> { +impl From for ValError { fn from(py_err: PyErr) -> Self { Self::InternalErr(py_err) } } -impl<'a> From> for ValError<'a> { +impl From> for ValError { fn from(py_downcast: PyDowncastError) -> Self { 
Self::InternalErr(PyTypeError::new_err(py_downcast.to_string())) } } -impl<'a> From>> for ValError<'a> { - fn from(line_errors: Vec>) -> Self { +impl From> for ValError { + fn from(line_errors: Vec) -> Self { Self::LineErrors(line_errors) } } -impl<'a> ValError<'a> { - pub fn new(error_type: ErrorType, input: &'a impl Input<'a>) -> ValError<'a> { +impl ValError { + pub fn new(error_type: ErrorType, input: &impl AsErrorValue) -> ValError { Self::LineErrors(vec![ValLineError::new(error_type, input)]) } - pub fn new_with_loc(error_type: ErrorType, input: &'a impl Input<'a>, loc: impl Into) -> ValError<'a> { + pub fn new_with_loc(error_type: ErrorType, input: &impl AsErrorValue, loc: impl Into) -> ValError { Self::LineErrors(vec![ValLineError::new_with_loc(error_type, input, loc)]) } - pub fn new_custom_input(error_type: ErrorType, input_value: InputValue<'a>) -> ValError<'a> { + pub fn new_custom_input(error_type: ErrorType, input_value: InputValue) -> ValError { Self::LineErrors(vec![ValLineError::new_custom_input(error_type, input_value)]) } @@ -62,31 +72,21 @@ impl<'a> ValError<'a> { other => other, } } - - /// a bit like clone but change the lifetime to match py - pub fn into_owned(self, py: Python<'_>) -> ValError<'_> { - match self { - ValError::LineErrors(errors) => errors.into_iter().map(|e| e.into_owned(py)).collect::>().into(), - ValError::InternalErr(err) => ValError::InternalErr(err), - ValError::Omit => ValError::Omit, - ValError::UseDefault => ValError::UseDefault, - } - } } /// A `ValLineError` is a single error that occurred during validation which is converted to a `PyLineError` /// to eventually form a `ValidationError`. /// I don't like the name `ValLineError`, but it's the best I could come up with (for now). 
#[cfg_attr(debug_assertions, derive(Debug))] -pub struct ValLineError<'a> { +pub struct ValLineError { pub error_type: ErrorType, // location is reversed so that adding an "outer" location item is pushing, it's reversed before showing to the user pub location: Location, - pub input_value: InputValue<'a>, + pub input_value: InputValue, } -impl<'a> ValLineError<'a> { - pub fn new(error_type: ErrorType, input: &'a impl Input<'a>) -> ValLineError<'a> { +impl ValLineError { + pub fn new(error_type: ErrorType, input: &impl AsErrorValue) -> ValLineError { Self { error_type, input_value: input.as_error_value(), @@ -94,7 +94,7 @@ impl<'a> ValLineError<'a> { } } - pub fn new_with_loc(error_type: ErrorType, input: &'a impl Input<'a>, loc: impl Into) -> ValLineError<'a> { + pub fn new_with_loc(error_type: ErrorType, input: &impl AsErrorValue, loc: impl Into) -> ValLineError { Self { error_type, input_value: input.as_error_value(), @@ -102,7 +102,7 @@ impl<'a> ValLineError<'a> { } } - pub fn new_with_full_loc(error_type: ErrorType, input: &'a impl Input<'a>, location: Location) -> ValLineError<'a> { + pub fn new_with_full_loc(error_type: ErrorType, input: &impl AsErrorValue, location: Location) -> ValLineError { Self { error_type, input_value: input.as_error_value(), @@ -110,7 +110,7 @@ impl<'a> ValLineError<'a> { } } - pub fn new_custom_input(error_type: ErrorType, input_value: InputValue<'a>) -> ValLineError<'a> { + pub fn new_custom_input(error_type: ErrorType, input_value: InputValue) -> ValLineError { Self { error_type, input_value, @@ -130,35 +130,20 @@ impl<'a> ValLineError<'a> { self.error_type = error_type; self } - - /// a bit like clone but change the lifetime to match py, used by ValError.into_owned above - pub fn into_owned(self, py: Python<'_>) -> ValLineError<'_> { - ValLineError { - error_type: self.error_type, - input_value: match self.input_value { - InputValue::PyAny(input) => InputValue::PyAny(input.to_object(py).into_ref(py)), - InputValue::JsonInput(input) 
=> InputValue::JsonInput(input), - InputValue::String(input) => InputValue::PyAny(input.to_object(py).into_ref(py)), - }, - location: self.location, - } - } } #[cfg_attr(debug_assertions, derive(Debug))] #[derive(Clone)] -pub enum InputValue<'a> { - PyAny(&'a PyAny), - JsonInput(JsonValue), - String(&'a str), +pub enum InputValue { + Python(PyObject), + Json(JsonValue), } -impl<'a> ToPyObject for InputValue<'a> { +impl ToPyObject for InputValue { fn to_object(&self, py: Python) -> PyObject { match self { - Self::PyAny(input) => input.into_py(py), - Self::JsonInput(input) => input.to_object(py), - Self::String(input) => input.into_py(py), + Self::Python(input) => input.clone_ref(py), + Self::Json(input) => input.to_object(py), } } } diff --git a/src/errors/mod.rs b/src/errors/mod.rs index bfc5b4329..131e54177 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -6,7 +6,7 @@ mod types; mod validation_exception; mod value_exception; -pub use self::line_error::{InputValue, ValError, ValLineError, ValResult}; +pub use self::line_error::{AsErrorValue, InputValue, ValError, ValLineError, ValResult}; pub use self::location::{AsLocItem, LocItem}; pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number}; pub use self::validation_exception::ValidationError; diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index d616e3022..91ef60a0d 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -225,12 +225,8 @@ fn get_url_prefix(py: Python, include_url: bool) -> Option<&str> { // used to convert a validation error back to ValError for wrap functions impl ValidationError { - pub(crate) fn into_val_error(self, py: Python<'_>) -> ValError<'_> { - self.line_errors - .into_iter() - .map(|e| e.into_val_line_error(py)) - .collect::>() - .into() + pub(crate) fn into_val_error(self) -> ValError { + self.line_errors.into_iter().map(Into::into).collect::>().into() } } @@ -416,7 +412,7 @@ pub struct 
PyLineError { input_value: PyObject, } -impl<'a> IntoPy for ValLineError<'a> { +impl IntoPy for ValLineError { fn into_py(self, py: Python<'_>) -> PyLineError { PyLineError { error_type: self.error_type, @@ -426,13 +422,13 @@ impl<'a> IntoPy for ValLineError<'a> { } } -impl PyLineError { +impl From for ValLineError { /// Used to extract line errors from a validation error for wrap functions - fn into_val_line_error(self, py: Python<'_>) -> ValLineError<'_> { + fn from(other: PyLineError) -> ValLineError { ValLineError { - error_type: self.error_type, - location: self.location, - input_value: InputValue::PyAny(self.input_value.into_ref(py)), + error_type: other.error_type, + location: other.location, + input_value: InputValue::Python(other.input_value), } } } diff --git a/src/errors/value_exception.rs b/src/errors/value_exception.rs index 7bc7e5227..a88610eef 100644 --- a/src/errors/value_exception.rs +++ b/src/errors/value_exception.rs @@ -2,9 +2,10 @@ use pyo3::exceptions::{PyException, PyValueError}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; -use crate::input::{Input, InputType}; +use crate::input::InputType; use crate::tools::extract_i64; +use super::line_error::AsErrorValue; use super::{ErrorType, ValError}; #[pyclass(extends=PyException, module="pydantic_core._pydantic_core")] @@ -105,7 +106,7 @@ impl PydanticCustomError { } impl PydanticCustomError { - pub fn into_val_error<'a>(self, input: &'a impl Input<'a>) -> ValError<'a> { + pub fn into_val_error(self, input: &impl AsErrorValue) -> ValError { let error_type = ErrorType::CustomError { error_type: self.error_type, message_template: self.message_template, @@ -184,7 +185,7 @@ impl PydanticKnownError { } impl PydanticKnownError { - pub fn into_val_error<'a>(self, input: &'a impl Input<'a>) -> ValError<'a> { + pub fn into_val_error(self, input: &impl AsErrorValue) -> ValError { ValError::new(self.error_type, input) } } diff --git a/src/input/datetime.rs b/src/input/datetime.rs index 
b89b76c1e..0a3cdf929 100644 --- a/src/input/datetime.rs +++ b/src/input/datetime.rs @@ -286,7 +286,7 @@ impl<'a> EitherDateTime<'a> { } } -pub fn bytes_as_date<'a>(input: &'a impl Input<'a>, bytes: &[u8]) -> ValResult<'a, EitherDate<'a>> { +pub fn bytes_as_date<'a>(input: &'a impl Input<'a>, bytes: &[u8]) -> ValResult> { match Date::parse_bytes(bytes) { Ok(date) => Ok(date.into()), Err(err) => Err(ValError::new( @@ -303,7 +303,7 @@ pub fn bytes_as_time<'a>( input: &'a impl Input<'a>, bytes: &[u8], microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, -) -> ValResult<'a, EitherTime<'a>> { +) -> ValResult> { match Time::parse_bytes_with_config( bytes, &TimeConfig { @@ -326,7 +326,7 @@ pub fn bytes_as_datetime<'a, 'b>( input: &'a impl Input<'a>, bytes: &'b [u8], microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, -) -> ValResult<'a, EitherDateTime<'a>> { +) -> ValResult> { match DateTime::parse_bytes_with_config( bytes, &TimeConfig { @@ -455,7 +455,7 @@ pub fn float_as_time<'a>(input: &'a impl Input<'a>, timestamp: f64) -> ValResult int_as_time(input, timestamp.floor() as i64, microseconds.round() as u32) } -fn map_timedelta_err<'a>(input: &'a impl Input<'a>, err: ParseError) -> ValError<'a> { +fn map_timedelta_err<'a>(input: &'a impl Input<'a>, err: ParseError) -> ValError { ValError::new( ErrorType::TimeDeltaParsing { error: Cow::Borrowed(err.get_documentation().unwrap_or_default()), @@ -469,7 +469,7 @@ pub fn bytes_as_timedelta<'a, 'b>( input: &'a impl Input<'a>, bytes: &'b [u8], microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior, -) -> ValResult<'a, EitherTimedelta<'a>> { +) -> ValResult> { match Duration::parse_bytes_with_config( bytes, &TimeConfig { diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 765ede739..ceb0495a9 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -49,7 +49,7 @@ impl TryFrom<&str> for InputType { /// * `strict_*` & `lax_*` if they 
have different behavior /// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { - fn as_error_value(&'a self) -> InputValue<'a>; + fn as_error_value(&self) -> InputValue; fn identity(&self) -> Option { None @@ -85,11 +85,11 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { false } - fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>>; + fn validate_args(&'a self) -> ValResult>; - fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>>; + fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult>; - fn parse_json(&'a self) -> ValResult<'a, JsonValue>; + fn parse_json(&'a self) -> ValResult; fn validate_str( &'a self, @@ -99,9 +99,9 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { fn validate_bytes(&'a self, strict: bool) -> ValResult>>; - fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch>; + fn validate_bool(&self, strict: bool) -> ValResult>; - fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>>; + fn validate_int(&'a self, strict: bool) -> ValResult>>; fn exact_int(&'a self) -> ValResult> { self.validate_int(true).and_then(|val_match| { @@ -121,7 +121,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { }) } - fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>>; + fn validate_float(&'a self, strict: bool) -> ValResult>>; fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> { if strict { @@ -230,15 +230,11 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { ) -> ValResult>; } -/// The problem to solve here is that iterating a `StringMapping` returns an owned -/// `StringMapping`, but all the other iterators return references. 
By introducing +/// The problem to solve here is that iterating collections often returns owned +/// values, but inputs are usually taken by reference. By introducing /// this trait we abstract over whether the return value from the iterator is owned /// or borrowed; all we care about is that we can borrow it again with `borrow_input` /// for some lifetime 'a. -/// -/// This lifetime `'a` is shorter than the original lifetime `'data` of the input, -/// which is only a problem in error branches. To resolve we have to call `into_owned` -/// to extend out the lifetime to match the original input. pub trait BorrowInput { type Input<'a>: Input<'a> where diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 0411a25d6..3e94adad6 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -31,10 +31,16 @@ impl AsLocItem for JsonValue { } } +impl AsLocItem for &JsonValue { + fn as_loc_item(&self) -> LocItem { + AsLocItem::as_loc_item(*self) + } +} + impl<'a> Input<'a> for JsonValue { - fn as_error_value(&'a self) -> InputValue<'a> { + fn as_error_value(&self) -> InputValue { // cloning JsonValue is cheap due to use of Arc - InputValue::JsonInput(self.clone()) + InputValue::Json(self.clone()) } fn is_none(&self) -> bool { @@ -54,7 +60,7 @@ impl<'a> Input<'a> for JsonValue { } } - fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { + fn validate_args(&'a self) -> ValResult> { match self { JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), JsonValue::Array(array) => Ok(JsonArgs::new(Some(array), None).into()), @@ -62,7 +68,7 @@ impl<'a> Input<'a> for JsonValue { } } - fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> { + fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult> { match self { JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object)).into()), _ => { @@ -78,7 +84,7 @@ impl<'a> Input<'a> for JsonValue { } } - fn parse_json(&'a self) -> 
ValResult<'a, JsonValue> { + fn parse_json(&'a self) -> ValResult { match self { JsonValue::Str(s) => JsonValue::parse(s.as_bytes(), true).map_err(|e| map_json_err(self, e)), _ => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), @@ -118,7 +124,7 @@ impl<'a> Input<'a> for JsonValue { } } - fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch> { + fn validate_bool(&self, strict: bool) -> ValResult> { match self { JsonValue::Bool(b) => Ok(ValidationMatch::exact(*b)), JsonValue::Str(s) if !strict => str_as_bool(self, s).map(ValidationMatch::lax), @@ -134,7 +140,7 @@ impl<'a> Input<'a> for JsonValue { } } - fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { + fn validate_int(&'a self, strict: bool) -> ValResult>> { match self { JsonValue::Int(i) => Ok(ValidationMatch::exact(EitherInt::I64(*i))), JsonValue::BigInt(b) => Ok(ValidationMatch::exact(EitherInt::BigInt(b.clone()))), @@ -145,7 +151,7 @@ impl<'a> Input<'a> for JsonValue { } } - fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { + fn validate_float(&'a self, strict: bool) -> ValResult>> { match self { JsonValue::Float(f) => Ok(ValidationMatch::exact(EitherFloat::F64(*f))), JsonValue::Int(i) => Ok(ValidationMatch::strict(EitherFloat::F64(*i as f64))), @@ -326,10 +332,18 @@ impl AsLocItem for String { } } +impl AsLocItem for &String { + fn as_loc_item(&self) -> LocItem { + AsLocItem::as_loc_item(*self) + } +} + /// Required for JSON Object keys so the string can behave like an Input impl<'a> Input<'a> for String { - fn as_error_value(&'a self) -> InputValue<'a> { - InputValue::String(self) + fn as_error_value(&self) -> InputValue { + // Justification for the clone: this is on the error pathway and we are generally ok + // with errors having a performance penalty + InputValue::Json(JsonValue::Str(self.clone())) } fn as_kwargs(&'a self, _py: Python<'a>) -> Option<&'a PyDict> { @@ -337,12 +351,12 @@ impl<'a> Input<'a> for String { } 
#[cfg_attr(has_coverage_attribute, coverage(off))] - fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { + fn validate_args(&'a self) -> ValResult> { Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) } #[cfg_attr(has_coverage_attribute, coverage(off))] - fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> { + fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult> { let class_name = class_name.to_string(); Err(ValError::new( ErrorType::DataclassType { @@ -353,7 +367,7 @@ impl<'a> Input<'a> for String { )) } - fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + fn parse_json(&'a self) -> ValResult { JsonValue::parse(self.as_bytes(), true).map_err(|e| map_json_err(self, e)) } @@ -374,18 +388,18 @@ impl<'a> Input<'a> for String { Ok(ValidationMatch::strict(self.as_bytes().into())) } - fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch> { + fn validate_bool(&self, _strict: bool) -> ValResult> { str_as_bool(self, self).map(ValidationMatch::lax) } - fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { + fn validate_int(&'a self, _strict: bool) -> ValResult>> { match self.parse() { Ok(i) => Ok(ValidationMatch::lax(EitherInt::I64(i))), Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), } } - fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { + fn validate_float(&'a self, _strict: bool) -> ValResult>> { str_as_float(self, self).map(ValidationMatch::lax) } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 90d2c2a8b..ba42abbb8 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -106,9 +106,15 @@ impl AsLocItem for PyAny { } } +impl AsLocItem for &'_ PyAny { + fn as_loc_item(&self) -> LocItem { + AsLocItem::as_loc_item(*self) + } +} + impl<'a> Input<'a> for PyAny { - fn as_error_value(&'a self) -> InputValue<'a> { - InputValue::PyAny(self) + fn 
as_error_value(&self) -> InputValue { + InputValue::Python(self.into()) } fn identity(&self) -> Option { @@ -154,7 +160,7 @@ impl<'a> Input<'a> for PyAny { self.is_callable() } - fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { + fn validate_args(&'a self) -> ValResult> { if let Ok(dict) = self.downcast::() { Ok(PyArgs::new(None, Some(dict)).into()) } else if let Ok(args_kwargs) = self.extract::() { @@ -170,7 +176,7 @@ impl<'a> Input<'a> for PyAny { } } - fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult<'a, GenericArguments<'a>> { + fn validate_dataclass_args(&'a self, class_name: &str) -> ValResult> { if let Ok(dict) = self.downcast::() { Ok(PyArgs::new(None, Some(dict)).into()) } else if let Ok(args_kwargs) = self.extract::() { @@ -189,7 +195,7 @@ impl<'a> Input<'a> for PyAny { } } - fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + fn parse_json(&'a self) -> ValResult { let bytes = if let Ok(py_bytes) = self.downcast::() { py_bytes.as_bytes() } else if let Ok(py_str) = self.downcast::() { @@ -296,7 +302,7 @@ impl<'a> Input<'a> for PyAny { Err(ValError::new(ErrorTypeDefaults::BytesType, self)) } - fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch> { + fn validate_bool(&self, strict: bool) -> ValResult> { if let Ok(bool) = self.downcast::() { return Ok(ValidationMatch::exact(bool.is_true())); } @@ -319,7 +325,7 @@ impl<'a> Input<'a> for PyAny { Err(ValError::new(ErrorTypeDefaults::BoolType, self)) } - fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { + fn validate_int(&'a self, strict: bool) -> ValResult>> { if self.is_exact_instance_of::() { return Ok(ValidationMatch::exact(EitherInt::Py(self))); } else if self.is_instance_of::() { @@ -359,7 +365,7 @@ impl<'a> Input<'a> for PyAny { Err(ValError::new(ErrorTypeDefaults::IntType, self)) } - fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch>> { + fn validate_float(&'a self, strict: bool) -> ValResult>> 
{ if let Ok(float) = self.downcast_exact::() { return Ok(ValidationMatch::exact(EitherFloat::Py(float))); } diff --git a/src/input/input_string.rs b/src/input/input_string.rs index e27ef6461..290247617 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -34,7 +34,7 @@ impl<'py> ToPyObject for StringMapping<'py> { } impl<'py> StringMapping<'py> { - pub fn new_key(py_key: &'py PyAny) -> ValResult<'py, StringMapping> { + pub fn new_key(py_key: &'py PyAny) -> ValResult { if let Ok(py_str) = py_key.downcast::() { Ok(Self::String(py_str)) } else { @@ -42,7 +42,7 @@ impl<'py> StringMapping<'py> { } } - pub fn new_value(py_value: &'py PyAny) -> ValResult<'py, Self> { + pub fn new_value(py_value: &'py PyAny) -> ValResult { if let Ok(py_str) = py_value.downcast::() { Ok(Self::String(py_str)) } else if let Ok(value) = py_value.downcast::() { @@ -63,10 +63,10 @@ impl AsLocItem for StringMapping<'_> { } impl<'a> Input<'a> for StringMapping<'a> { - fn as_error_value(&'a self) -> InputValue<'a> { + fn as_error_value(&self) -> InputValue { match self { Self::String(s) => s.as_error_value(), - Self::Mapping(d) => InputValue::PyAny(d), + Self::Mapping(d) => d.as_error_value(), } } @@ -74,19 +74,19 @@ impl<'a> Input<'a> for StringMapping<'a> { None } - fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> { + fn validate_args(&'a self) -> ValResult> { // do we want to support this? 
Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) } - fn validate_dataclass_args(&'a self, _dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>> { + fn validate_dataclass_args(&'a self, _dataclass_name: &str) -> ValResult> { match self { StringMapping::String(_) => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)), StringMapping::Mapping(m) => Ok(GenericArguments::StringMapping(m)), } } - fn parse_json(&'a self) -> ValResult<'a, JsonValue> { + fn parse_json(&'a self) -> ValResult { match self { Self::String(s) => { let str = py_string_str(s)?; @@ -114,14 +114,14 @@ impl<'a> Input<'a> for StringMapping<'a> { } } - fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch> { + fn validate_bool(&self, _strict: bool) -> ValResult> { match self { Self::String(s) => str_as_bool(self, py_string_str(s)?).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BoolType, self)), } } - fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { + fn validate_int(&'a self, _strict: bool) -> ValResult>> { match self { Self::String(s) => match py_string_str(s)?.parse() { Ok(i) => Ok(ValidationMatch::strict(EitherInt::I64(i))), @@ -131,7 +131,7 @@ impl<'a> Input<'a> for StringMapping<'a> { } } - fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch>> { + fn validate_float(&'a self, _strict: bool) -> ValResult>> { match self { Self::String(s) => str_as_float(self, py_string_str(s)?).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::FloatType, self)), diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 56c7098df..fa70880ca 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -153,7 +153,7 @@ impl<'a, INPUT: Input<'a>> MaxLengthCheck<'a, INPUT> { } } - fn incr(&mut self) -> ValResult<'a, ()> { + fn incr(&mut self) -> ValResult<()> { if let Some(max_length) = self.max_length { self.current_length 
+= 1; if self.current_length > max_length { @@ -193,7 +193,7 @@ fn validate_iter_to_vec<'a, 's>( mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>, validator: &'s CombinedValidator, state: &mut ValidationState, -) -> ValResult<'a, Vec> { +) -> ValResult> { let mut output: Vec = Vec::with_capacity(capacity); let mut errors: Vec = Vec::new(); for (index, item_result) in iter.enumerate() { @@ -259,7 +259,7 @@ fn validate_iter_to_set<'a, 's>( max_length: Option, validator: &'s CombinedValidator, state: &mut ValidationState, -) -> ValResult<'a, ()> { +) -> ValResult<()> { let mut errors: Vec = Vec::new(); for (index, item_result) in iter.enumerate() { let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?; @@ -303,7 +303,7 @@ fn no_validator_iter_to_vec<'a, 's>( input: &'a (impl Input<'a> + 'a), iter: impl Iterator + 'a)>>, mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>, -) -> ValResult<'a, Vec> { +) -> ValResult> { iter.enumerate() .map(|(index, result)| { let v = result.map_err(|e| any_next_error!(py, e, input, index))?; @@ -348,7 +348,7 @@ impl<'a> GenericIterable<'a> { field_type: &'static str, validator: &'s CombinedValidator, state: &mut ValidationState, - ) -> ValResult<'a, Vec> { + ) -> ValResult> { let actual_length = self.generic_len(); let capacity = actual_length.unwrap_or(DEFAULT_CAPACITY); let max_length_check = MaxLengthCheck::new(max_length, field_type, input, actual_length); @@ -381,7 +381,7 @@ impl<'a> GenericIterable<'a> { field_type: &'static str, validator: &'s CombinedValidator, state: &mut ValidationState, - ) -> ValResult<'a, ()> { + ) -> ValResult<()> { macro_rules! 
validate_set { ($iter:expr) => { validate_iter_to_set(py, set, $iter, input, field_type, max_length, validator, state) @@ -406,7 +406,7 @@ impl<'a> GenericIterable<'a> { input: &'a impl Input<'a>, field_type: &'static str, max_length: Option, - ) -> ValResult<'a, Vec> { + ) -> ValResult> { let actual_length = self.generic_len(); let max_length_check = MaxLengthCheck::new(max_length, field_type, input, actual_length); @@ -456,13 +456,13 @@ pub struct DictGenericIterator<'py> { } impl<'py> DictGenericIterator<'py> { - pub fn new(dict: &'py PyDict) -> ValResult<'py, Self> { + pub fn new(dict: &'py PyDict) -> ValResult { Ok(Self { dict_iter: dict.iter() }) } } impl<'py> Iterator for DictGenericIterator<'py> { - type Item = ValResult<'py, (&'py PyAny, &'py PyAny)>; + type Item = ValResult<(&'py PyAny, &'py PyAny)>; fn next(&mut self) -> Option { self.dict_iter.next().map(Ok) @@ -475,7 +475,7 @@ pub struct MappingGenericIterator<'py> { iter: &'py PyIterator, } -fn mapping_err<'py>(err: PyErr, py: Python<'py>, input: &'py impl Input<'py>) -> ValError<'py> { +fn mapping_err<'py>(err: PyErr, py: Python<'py>, input: &'py impl Input<'py>) -> ValError { ValError::new( ErrorType::MappingType { error: py_err_string(py, err).into(), @@ -486,7 +486,7 @@ fn mapping_err<'py>(err: PyErr, py: Python<'py>, input: &'py impl Input<'py>) -> } impl<'py> MappingGenericIterator<'py> { - pub fn new(mapping: &'py PyMapping) -> ValResult<'py, Self> { + pub fn new(mapping: &'py PyMapping) -> ValResult { let py = mapping.py(); let input: &PyAny = mapping; let iter = mapping @@ -501,7 +501,7 @@ impl<'py> MappingGenericIterator<'py> { const MAPPING_TUPLE_ERROR: &str = "Mapping items must be tuples of (key, value) pairs"; impl<'py> Iterator for MappingGenericIterator<'py> { - type Item = ValResult<'py, (&'py PyAny, &'py PyAny)>; + type Item = ValResult<(&'py PyAny, &'py PyAny)>; fn next(&mut self) -> Option { Some(match self.iter.next()? 
{ @@ -524,14 +524,14 @@ pub struct StringMappingGenericIterator<'py> { } impl<'py> StringMappingGenericIterator<'py> { - pub fn new(dict: &'py PyDict) -> ValResult<'py, Self> { + pub fn new(dict: &'py PyDict) -> ValResult { Ok(Self { dict_iter: dict.iter() }) } } impl<'py> Iterator for StringMappingGenericIterator<'py> { // key (the first member of the tuple could be a simple String) - type Item = ValResult<'py, (StringMapping<'py>, StringMapping<'py>)>; + type Item = ValResult<(StringMapping<'py>, StringMapping<'py>)>; fn next(&mut self) -> Option { match self.dict_iter.next() { @@ -558,7 +558,7 @@ pub struct AttributesGenericIterator<'py> { } impl<'py> AttributesGenericIterator<'py> { - pub fn new(py_any: &'py PyAny) -> ValResult<'py, Self> { + pub fn new(py_any: &'py PyAny) -> ValResult { Ok(Self { object: py_any, attributes_iterator: py_any.dir().into_iter(), @@ -567,7 +567,7 @@ impl<'py> AttributesGenericIterator<'py> { } impl<'py> Iterator for AttributesGenericIterator<'py> { - type Item = ValResult<'py, (&'py PyAny, &'py PyAny)>; + type Item = ValResult<(&'py PyAny, &'py PyAny)>; fn next(&mut self) -> Option { // loop until we find an attribute who's name does not start with underscore, @@ -610,7 +610,7 @@ pub struct JsonObjectGenericIterator<'py> { } impl<'py> JsonObjectGenericIterator<'py> { - pub fn new(json_object: &'py JsonObject) -> ValResult<'py, Self> { + pub fn new(json_object: &'py JsonObject) -> ValResult { Ok(Self { object_iter: json_object.iter(), }) @@ -618,7 +618,7 @@ impl<'py> JsonObjectGenericIterator<'py> { } impl<'py> Iterator for JsonObjectGenericIterator<'py> { - type Item = ValResult<'py, (&'py String, &'py JsonValue)>; + type Item = ValResult<(&'py String, &'py JsonValue)>; fn next(&mut self) -> Option { self.object_iter.next().map(|(key, value)| Ok((key, value))) @@ -670,8 +670,8 @@ impl GenericPyIterator { } } - pub fn input_as_error_value<'py>(&self, py: Python<'py>) -> InputValue<'py> { - 
InputValue::PyAny(self.obj.clone_ref(py).into_ref(py)) + pub fn input_as_error_value(&self, py: Python<'_>) -> InputValue { + InputValue::Python(self.obj.clone_ref(py)) } pub fn index(&self) -> usize { @@ -699,8 +699,8 @@ impl GenericJsonIterator { } } - pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> { - InputValue::JsonInput(JsonValue::Array(self.array.clone())) + pub fn input_as_error_value(&self, _py: Python<'_>) -> InputValue { + InputValue::Json(JsonValue::Array(self.array.clone())) } pub fn index(&self) -> usize { @@ -758,7 +758,7 @@ pub enum EitherString<'a> { } impl<'a> EitherString<'a> { - pub fn as_cow(&self) -> ValResult<'a, Cow> { + pub fn as_cow(&self) -> ValResult> { match self { Self::Cow(data) => Ok(data.clone()), Self::Py(py_str) => Ok(Cow::Borrowed(py_string_str(py_str)?)), @@ -800,7 +800,7 @@ impl<'a> IntoPy for EitherString<'a> { pub fn py_string_str(py_str: &PyString) -> ValResult<&str> { py_str .to_str() - .map_err(|_| ValError::new_custom_input(ErrorTypeDefaults::StringUnicode, InputValue::PyAny(py_str as &PyAny))) + .map_err(|_| ValError::new_custom_input(ErrorTypeDefaults::StringUnicode, InputValue::Python(py_str.into()))) } #[cfg_attr(debug_assertions, derive(Debug))] @@ -870,7 +870,7 @@ impl<'a> EitherInt<'a> { Ok(Self::BigInt(big_int)) } } - pub fn into_i64(self, py: Python<'a>) -> ValResult<'a, i64> { + pub fn into_i64(self, py: Python<'a>) -> ValResult { match self { EitherInt::I64(i) => Ok(i), EitherInt::U64(u) => match i64::try_from(u) { @@ -893,7 +893,7 @@ impl<'a> EitherInt<'a> { } } - pub fn as_int(&self) -> ValResult<'a, Int> { + pub fn as_int(&self) -> ValResult { match self { EitherInt::I64(i) => Ok(Int::I64(*i)), EitherInt::U64(u) => match i64::try_from(*u) { diff --git a/src/input/shared.rs b/src/input/shared.rs index 718210098..647bce8a3 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -20,7 +20,7 @@ pub fn get_enum_meta_object(py: Python) -> Py { .clone() } -pub fn 
map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError<'a> { +pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError { ValError::new( ErrorType::JsonInvalid { error: error.to_string(), @@ -30,7 +30,7 @@ pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> Val ) } -pub fn str_as_bool<'a>(input: &'a impl Input<'a>, str: &str) -> ValResult<'a, bool> { +pub fn str_as_bool<'a>(input: &'a impl Input<'a>, str: &str) -> ValResult { if str == "0" || str.eq_ignore_ascii_case("f") || str.eq_ignore_ascii_case("n") @@ -52,7 +52,7 @@ pub fn str_as_bool<'a>(input: &'a impl Input<'a>, str: &str) -> ValResult<'a, bo } } -pub fn int_as_bool<'a>(input: &'a impl Input<'a>, int: i64) -> ValResult<'a, bool> { +pub fn int_as_bool<'a>(input: &'a impl Input<'a>, int: i64) -> ValResult { if int == 0 { Ok(false) } else if int == 1 { @@ -82,7 +82,7 @@ fn strip_underscores(s: &str) -> Option { /// max length of the input is 4300, see /// https://docs.python.org/3/whatsnew/3.11.html#other-cpython-implementation-changes and /// https://github.com/python/cpython/issues/95778 for more info in that length bound -pub fn str_as_int<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult<'s, EitherInt<'s>> { +pub fn str_as_int<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult> { let len = str.len(); if len > 4300 { Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input)) @@ -106,7 +106,7 @@ pub fn str_as_int<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult< } /// parse a float as a float -pub fn str_as_float<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult<'s, EitherFloat<'s>> { +pub fn str_as_float<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult> { match str.parse() { Ok(float) => Ok(EitherFloat::F64(float)), Err(_) => match strip_underscores(str).and_then(|stripped| stripped.parse().ok()) { @@ -140,7 +140,7 @@ fn strip_decimal_zeros(s: &str) -> Option<&str> { 
None } -pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<'a, EitherInt<'a>> { +pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult> { if float.is_infinite() || float.is_nan() { Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)) } else if float % 1.0 != 0.0 { @@ -152,7 +152,7 @@ pub fn float_as_int<'a>(input: &'a impl Input<'a>, float: f64) -> ValResult<'a, } } -pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a PyAny) -> ValResult<'a, EitherInt<'a>> { +pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a PyAny) -> ValResult> { if !decimal.call_method0(intern!(py, "is_finite"))?.extract::()? { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); } diff --git a/src/lookup_key.rs b/src/lookup_key.rs index f833c00af..21b584690 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -111,7 +111,7 @@ impl LookupKey { pub fn py_get_dict_item<'data, 's>( &'s self, dict: &'data PyDict, - ) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> { + ) -> ValResult> { match self { Self::Simple { py_key, path, .. } => match dict.get_item(py_key)? { Some(value) => Ok(Some((path, value))), @@ -148,7 +148,7 @@ impl LookupKey { pub fn py_get_string_mapping_item<'data, 's>( &'s self, dict: &'data PyDict, - ) -> ValResult<'data, Option<(&'s LookupPath, StringMapping<'data>)>> { + ) -> ValResult)>> { if let Some((path, py_any)) = self.py_get_dict_item(dict)? { let value = StringMapping::new_value(py_any)?; Ok(Some((path, value))) @@ -160,7 +160,7 @@ impl LookupKey { pub fn py_get_mapping_item<'data, 's>( &'s self, dict: &'data PyMapping, - ) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> { + ) -> ValResult> { match self { Self::Simple { py_key, path, .. 
} => match dict.get_item(py_key) { Ok(value) => Ok(Some((path, value))), @@ -198,7 +198,7 @@ impl LookupKey { &'s self, obj: &'data PyAny, kwargs: Option<&'data PyDict>, - ) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> { + ) -> ValResult> { match self._py_get_attr(obj, kwargs) { Ok(v) => Ok(v), Err(err) => { @@ -266,7 +266,7 @@ impl LookupKey { pub fn json_get<'data, 's>( &'s self, dict: &'data JsonObject, - ) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonValue)>> { + ) -> ValResult> { match self { Self::Simple { key, path, .. } => match dict.get(key) { Some(value) => Ok(Some((path, value))), @@ -316,7 +316,7 @@ impl LookupKey { input: &'d impl Input<'d>, loc_by_alias: bool, field_name: &str, - ) -> ValLineError<'d> { + ) -> ValLineError { if loc_by_alias { let lookup_path = match self { Self::Simple { path, .. } => path, @@ -369,12 +369,7 @@ impl LookupPath { } } - pub fn apply_error_loc<'a>( - &self, - mut line_error: ValLineError<'a>, - loc_by_alias: bool, - field_name: &str, - ) -> ValLineError<'a> { + pub fn apply_error_loc(&self, mut line_error: ValLineError, loc_by_alias: bool, field_name: &str) -> ValLineError { if loc_by_alias { for path_item in self.iter().rev() { line_error = line_error.with_outer_location(path_item.clone().into()); diff --git a/src/validators/any.rs b/src/validators/any.rs index 2fad89091..244117a01 100644 --- a/src/validators/any.rs +++ b/src/validators/any.rs @@ -32,7 +32,7 @@ impl Validator for AnyValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { // in a union, Any should be preferred to doing lax coercions state.floor_exactness(Exactness::Strict); Ok(input.to_object(py)) diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 7ae65d579..0405d54e5 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -166,7 +166,7 @@ impl Validator for ArgumentsValidator { py: 
Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let args = input.validate_args()?; let mut output_args: Vec = Vec::with_capacity(self.positional_params_count); diff --git a/src/validators/bool.rs b/src/validators/bool.rs index bcd48e991..b2bb35e5d 100644 --- a/src/validators/bool.rs +++ b/src/validators/bool.rs @@ -35,7 +35,7 @@ impl Validator for BoolValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { // TODO in theory this could be quicker if we used PyBool rather than going to a bool // and back again, might be worth profiling? input diff --git a/src/validators/bytes.rs b/src/validators/bytes.rs index 78a8acb24..9dd07d7d3 100644 --- a/src/validators/bytes.rs +++ b/src/validators/bytes.rs @@ -45,7 +45,7 @@ impl Validator for BytesValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { input .validate_bytes(state.strict_or(self.strict)) .map(|m| m.unpack(state).into_py(py)) @@ -71,7 +71,7 @@ impl Validator for BytesConstrainedValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let either_bytes = input.validate_bytes(state.strict_or(self.strict))?.unpack(state); let len = either_bytes.len()?; diff --git a/src/validators/call.rs b/src/validators/call.rs index e0649aa53..d43df4984 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -77,7 +77,7 @@ impl Validator for CallValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let args = self.arguments_validator.validate(py, input, state)?; let return_value = if let Ok((args, kwargs)) = args.extract::<(&PyTuple, &PyDict)>(py) { diff --git 
a/src/validators/callable.rs b/src/validators/callable.rs index 3075e182e..e05c99b12 100644 --- a/src/validators/callable.rs +++ b/src/validators/callable.rs @@ -30,7 +30,7 @@ impl Validator for CallableValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { state.floor_exactness(Exactness::Lax); match input.callable() { true => Ok(input.to_object(py)), diff --git a/src/validators/chain.rs b/src/validators/chain.rs index d8da86e30..299801dea 100644 --- a/src/validators/chain.rs +++ b/src/validators/chain.rs @@ -75,7 +75,7 @@ impl Validator for ChainValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let mut steps_iter = self.steps.iter(); let first_step = steps_iter.next().unwrap(); let value = first_step.validate(py, input, state)?; diff --git a/src/validators/custom_error.rs b/src/validators/custom_error.rs index 4ea31aa5a..ec0af567a 100644 --- a/src/validators/custom_error.rs +++ b/src/validators/custom_error.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::py_schema_err; +use crate::errors::AsErrorValue; use crate::errors::{ErrorType, PydanticCustomError, PydanticKnownError, ValError, ValResult}; use crate::input::Input; use crate::tools::SchemaDict; @@ -49,7 +50,7 @@ impl CustomError { } } - pub fn as_val_error<'a>(&self, input: &'a impl Input<'a>) -> ValError<'a> { + pub fn as_val_error(&self, input: &impl AsErrorValue) -> ValError { match self { CustomError::KnownError(ref known_error) => known_error.clone().into_val_error(input), CustomError::Custom(ref custom_error) => custom_error.clone().into_val_error(input), @@ -93,7 +94,7 @@ impl Validator for CustomErrorValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { self.validator 
.validate(py, input, state) .map_err(|_| self.custom_error.as_val_error(input)) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 4140b1b26..dbbfc5ee2 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -144,7 +144,7 @@ impl Validator for DataclassArgsValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let args = input.validate_dataclass_args(&self.dataclass_name)?; let output_dict = PyDict::new(py); @@ -202,8 +202,7 @@ impl Validator for DataclassArgsValidator { ErrorTypeDefaults::MultipleArgumentValues, kw_value, field.name.clone(), - ) - .into_owned(py), + ), ); } // found a positional argument, validate it @@ -226,10 +225,9 @@ impl Validator for DataclassArgsValidator { errors.extend(line_errors.into_iter().map(|err| { lookup_path .apply_error_loc(err, self.loc_by_alias, &field.name) - .into_owned(py) })); } - Err(err) => return Err(err.into_owned(py)), + Err(err) => return Err(err), } } // found neither, check if there is a default value, otherwise error @@ -294,8 +292,7 @@ impl Validator for DataclassArgsValidator { ErrorTypeDefaults::UnexpectedKeywordArgument, value, raw_key.as_loc_item(), - ) - .into_owned(py), + ), ); } ExtraBehavior::Ignore => {} @@ -375,7 +372,7 @@ impl Validator for DataclassArgsValidator { field_name: &'data str, field_value: &'data PyAny, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let dict: &PyDict = obj.downcast()?; let ok = |output: PyObject| { @@ -518,7 +515,7 @@ impl Validator for DataclassValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if let Some(self_instance) = state.extra().self_instance { // in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__` return self.validate_init(py, self_instance, input, 
state); @@ -560,7 +557,7 @@ impl Validator for DataclassValidator { field_name: &'data str, field_value: &'data PyAny, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if self.frozen { return Err(ValError::new(ErrorTypeDefaults::FrozenInstance, field_value)); } @@ -600,7 +597,7 @@ impl DataclassValidator { self_instance: &'s PyAny, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { // we need to set `self_instance` to None for nested validators as we don't want to operate on the self_instance // instance anymore let state = &mut state.rebind_extra(|extra| extra.self_instance = None); @@ -627,7 +624,7 @@ impl DataclassValidator { dc: &PyAny, val_output: PyObject, input: &'data impl Input<'data>, - ) -> ValResult<'data, ()> { + ) -> ValResult<()> { let (dc_dict, post_init_kwargs): (&PyAny, &PyAny) = val_output.extract(py)?; if self.slots { let dc_dict: &PyDict = dc_dict.downcast()?; diff --git a/src/validators/date.rs b/src/validators/date.rs index 7c79101f4..2329dcf2a 100644 --- a/src/validators/date.rs +++ b/src/validators/date.rs @@ -44,7 +44,7 @@ impl Validator for DateValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let strict = state.strict_or(self.strict); let date = match input.validate_date(strict) { Ok(val_match) => val_match.unpack(state), @@ -109,7 +109,7 @@ impl Validator for DateValidator { /// "exact date", e.g. has a zero time component. 
/// /// Ok(None) means that this is not relevant to dates (the input was not a datetime nor a string) -fn date_from_datetime<'data>(input: &'data impl Input<'data>) -> Result>, ValError<'data>> { +fn date_from_datetime<'data>(input: &'data impl Input<'data>) -> Result>, ValError> { let either_dt = match input.validate_datetime(false, speedate::MicrosecondsPrecisionOverflowBehavior::Truncate) { Ok(val_match) => val_match.into_inner(), // if the error was a parsing error, update the error type from DatetimeParsing to DateFromDatetimeParsing diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index edbd399e7..156fad699 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -63,7 +63,7 @@ impl Validator for DateTimeValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let strict = state.strict_or(self.strict); let datetime = input .validate_datetime(strict, self.microseconds_precision)? 
@@ -263,7 +263,7 @@ impl TZConstraint { } } - pub(super) fn tz_check<'d>(&self, tz_offset: Option, input: &'d impl Input<'d>) -> ValResult<'d, ()> { + pub(super) fn tz_check<'d>(&self, tz_offset: Option, input: &'d impl Input<'d>) -> ValResult<()> { match (self, tz_offset) { (TZConstraint::Aware(_), None) => return Err(ValError::new(ErrorTypeDefaults::TimezoneAware, input)), (TZConstraint::Aware(Some(tz_expected)), Some(tz_actual)) => { diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index b9435f046..85f81a38e 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -83,11 +83,7 @@ impl_py_gc_traverse!(DecimalValidator { gt }); -fn extract_decimal_digits_info<'data>( - decimal: &PyAny, - normalized: bool, - py: Python<'data>, -) -> ValResult<'data, (u64, u64)> { +fn extract_decimal_digits_info(decimal: &PyAny, normalized: bool, py: Python<'_>) -> ValResult<(u64, u64)> { let mut normalized_decimal: Option<&PyAny> = None; if normalized { normalized_decimal = Some(decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal)); @@ -124,7 +120,7 @@ impl Validator for DecimalValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let decimal = input.validate_decimal(state.strict_or(self.strict), py)?; if !self.allow_inf_nan || self.check_digits { @@ -269,11 +265,7 @@ impl Validator for DecimalValidator { } } -pub(crate) fn create_decimal<'a>( - arg: &'a PyAny, - input: &'a impl Input<'a>, - py: Python<'a>, -) -> ValResult<'a, &'a PyAny> { +pub(crate) fn create_decimal<'a>(arg: &'a PyAny, input: &'a impl Input<'a>, py: Python<'a>) -> ValResult<&'a PyAny> { let decimal_type_obj: Py = get_decimal_type(py); decimal_type_obj .call1(py, (arg,)) @@ -293,10 +285,10 @@ pub(crate) fn create_decimal<'a>( fn handle_decimal_new_error<'a>( py: Python<'a>, - input: InputValue<'a>, + input: InputValue, error: PyErr, decimal_exception: &'a PyAny, -) -> 
ValError<'a> { +) -> ValError { if error.matches(py, decimal_exception) { ValError::new_custom_input(ErrorTypeDefaults::DecimalParsing, input) } else if error.matches(py, PyTypeError::type_object(py)) { diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 979278bb9..7297bd27a 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -72,7 +72,7 @@ impl Validator for DefinitionRefValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let validator = self.definition.get().unwrap(); if let Some(id) = input.identity() { if state.recursion_guard.contains_or_insert(id, self.definition.id()) { @@ -99,7 +99,7 @@ impl Validator for DefinitionRefValidator { field_name: &'data str, field_value: &'data PyAny, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let validator = self.definition.get().unwrap(); if let Some(id) = obj.identity() { if state.recursion_guard.contains_or_insert(id, self.definition.id()) { diff --git a/src/validators/dict.rs b/src/validators/dict.rs index 3ac284b2a..d6507e8e9 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -72,7 +72,7 @@ impl Validator for DictValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let strict = state.strict_or(self.strict); let dict = input.validate_dict(strict)?; match dict { @@ -103,9 +103,9 @@ impl DictValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - mapping_iter: impl Iterator>, + mapping_iter: impl Iterator>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let output = PyDict::new(py); let mut errors: Vec = Vec::new(); @@ -113,34 +113,31 @@ impl DictValidator { let value_validator = self.value_validator.as_ref(); for item_result in mapping_iter { let (key, value) = 
item_result?; - let key = key.borrow_input(); - let value = value.borrow_input(); - let output_key = match key_validator.validate(py, key, state) { + let output_key = match key_validator.validate(py, key.borrow_input(), state) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { // these are added in reverse order so [key] is shunted along by the second call errors.push( err.with_outer_location("[key]".into()) - .with_outer_location(key.as_loc_item()) - .into_owned(py), + .with_outer_location(key.as_loc_item()), ); } None } Err(ValError::Omit) => continue, - Err(err) => return Err(err.into_owned(py)), + Err(err) => return Err(err), }; - let output_value = match value_validator.validate(py, value, state) { + let output_value = match value_validator.validate(py, value.borrow_input(), state) { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(key.as_loc_item()).into_owned(py)); + errors.push(err.with_outer_location(key.as_loc_item())); } None } Err(ValError::Omit) => continue, - Err(err) => return Err(err.into_owned(py)), + Err(err) => return Err(err), }; if let (Some(key), Some(value)) = (output_key, output_value) { output.set_item(key, value)?; diff --git a/src/validators/float.rs b/src/validators/float.rs index 126e539c2..9bb293964 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -69,7 +69,7 @@ impl Validator for FloatValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let either_float = input.validate_float(state.strict_or(self.strict))?.unpack(state); if !self.allow_inf_nan && !either_float.as_f64().is_finite() { return Err(ValError::new(ErrorTypeDefaults::FiniteNumber, input)); @@ -101,7 +101,7 @@ impl Validator for ConstrainedFloatValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, 
- ) -> ValResult<'data, PyObject> { + ) -> ValResult { let either_float = input.validate_float(state.strict_or(self.strict))?.unpack(state); let float: f64 = either_float.as_f64(); if !self.allow_inf_nan && !float.is_finite() { diff --git a/src/validators/frozenset.rs b/src/validators/frozenset.rs index 190a8672d..ceb60eb9d 100644 --- a/src/validators/frozenset.rs +++ b/src/validators/frozenset.rs @@ -33,7 +33,7 @@ impl Validator for FrozenSetValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let collection = input.validate_frozenset(state.strict_or(self.strict))?; let exactness = match &collection { GenericIterable::FrozenSet(_) => Exactness::Exact, diff --git a/src/validators/function.rs b/src/validators/function.rs index 4c5ad9c29..e7134ab29 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -94,7 +94,7 @@ macro_rules! impl_validator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState<'_>, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let validate = |v, s: &mut ValidationState<'_>| self.validator.validate(py, v, s); self._validate(validate, py, input, state) } @@ -105,7 +105,7 @@ macro_rules! 
impl_validator { field_name: &'data str, field_value: &'data PyAny, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let validate = move |v, s: &mut ValidationState<'_>| { self.validator .validate_assignment(py, v, field_name, field_value, s) @@ -135,11 +135,11 @@ impl_build!(FunctionBeforeValidator, "function-before"); impl FunctionBeforeValidator { fn _validate<'s, 'data>( &'s self, - call: impl FnOnce(&'data PyAny, &mut ValidationState<'_>) -> ValResult<'data, PyObject>, + call: impl FnOnce(&'data PyAny, &mut ValidationState<'_>) -> ValResult, py: Python<'data>, input: &'data impl Input<'data>, state: &'s mut ValidationState<'_>, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let r = if self.info_arg { let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone()); self.func.call1(py, (input.to_object(py), info)) @@ -168,11 +168,11 @@ impl_build!(FunctionAfterValidator, "function-after"); impl FunctionAfterValidator { fn _validate<'s, 'data, I: Input<'data>>( &'s self, - call: impl FnOnce(&'data I, &mut ValidationState<'_>) -> ValResult<'data, PyObject>, + call: impl FnOnce(&'data I, &mut ValidationState<'_>) -> ValResult, py: Python<'data>, input: &'data I, state: &mut ValidationState<'_>, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let v = call(input, state)?; let r = if self.info_arg { let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone()); @@ -230,7 +230,7 @@ impl Validator for FunctionPlainValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let r = if self.info_arg { let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone()); self.func.call1(py, (input.to_object(py), info)) @@ -294,7 +294,7 @@ impl FunctionWrapValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> 
ValResult<'data, PyObject> { + ) -> ValResult { let r = if self.info_arg { let info = ValidationInfo::new(py, state.extra(), &self.config, self.field_name.clone()); self.func.call1(py, (input.to_object(py), handler, info)) @@ -317,7 +317,7 @@ impl Validator for FunctionWrapValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let handler = ValidatorCallable { validator: InternalValidator::new( py, @@ -341,7 +341,7 @@ impl Validator for FunctionWrapValidator { field_name: &'data str, field_value: &'data PyAny, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let handler = AssignmentValidatorCallable { validator: InternalValidator::new( py, @@ -438,7 +438,7 @@ macro_rules! py_err_string { /// Only `ValueError` (including `PydanticCustomError` and `ValidationError`) and `AssertionError` are considered /// as validation errors, `TypeError` is now considered as a runtime error to catch errors in function signatures -pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> ValError<'a> { +pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> ValError { if err.is_instance_of::(py) { let error_value = err.value(py); if let Ok(pydantic_value_error) = error_value.extract::() { @@ -446,7 +446,7 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) -> } else if let Ok(pydantic_error_type) = error_value.extract::() { pydantic_error_type.into_val_error(input) } else if let Ok(validation_error) = err.value(py).extract::() { - validation_error.into_val_error(py) + validation_error.into_val_error() } else { py_err_string!(py, err, error_value, ValueError, input) } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 94497d228..3b5fedd97 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -64,7 +64,7 @@ impl Validator for 
GeneratorValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let iterator = input.validate_iter()?; let validator = self.item_validator.as_ref().map(|v| { InternalValidator::new( diff --git a/src/validators/int.rs b/src/validators/int.rs index dabfb5115..a72afcb0e 100644 --- a/src/validators/int.rs +++ b/src/validators/int.rs @@ -48,7 +48,7 @@ impl Validator for IntValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { input .validate_int(state.strict_or(self.strict)) .map(|val_match| val_match.unpack(state).into_py(py)) @@ -77,7 +77,7 @@ impl Validator for ConstrainedIntValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let either_int = input.validate_int(state.strict_or(self.strict))?.unpack(state); let int_value = either_int.as_int()?; diff --git a/src/validators/is_instance.rs b/src/validators/is_instance.rs index 189589d6a..8355710d9 100644 --- a/src/validators/is_instance.rs +++ b/src/validators/is_instance.rs @@ -61,7 +61,7 @@ impl Validator for IsInstanceValidator { py: Python<'data>, input: &'data impl Input<'data>, _state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if !input.is_python() { return Err(ValError::InternalErr(PyNotImplementedError::new_err( "Cannot check isinstance when validating from json, \ diff --git a/src/validators/is_subclass.rs b/src/validators/is_subclass.rs index 7a89ef36c..e20a5c291 100644 --- a/src/validators/is_subclass.rs +++ b/src/validators/is_subclass.rs @@ -48,7 +48,7 @@ impl Validator for IsSubclassValidator { py: Python<'data>, input: &'data impl Input<'data>, _state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { match input.input_is_subclass(self.class.as_ref(py))? 
{ true => Ok(input.to_object(py)), false => Err(ValError::new( diff --git a/src/validators/json.rs b/src/validators/json.rs index 9dfb5fae2..7ce1f9e29 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -49,12 +49,12 @@ impl Validator for JsonValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let json_value = input.parse_json()?; match self.validator { Some(ref validator) => match validator.validate(py, &json_value, state) { Ok(v) => Ok(v), - Err(err) => Err(err.into_owned(py)), + Err(err) => Err(err), }, None => Ok(json_value.to_object(py)), } diff --git a/src/validators/json_or_python.rs b/src/validators/json_or_python.rs index 302cbdaf6..7cdb10adf 100644 --- a/src/validators/json_or_python.rs +++ b/src/validators/json_or_python.rs @@ -54,7 +54,7 @@ impl Validator for JsonOrPython { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { match state.extra().input_type { InputType::Python => self.python.validate(py, input, state), _ => self.json.validate(py, input, state), diff --git a/src/validators/lax_or_strict.rs b/src/validators/lax_or_strict.rs index 78021cd4c..c130cc83a 100644 --- a/src/validators/lax_or_strict.rs +++ b/src/validators/lax_or_strict.rs @@ -61,7 +61,7 @@ impl Validator for LaxOrStrictValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if state.strict_or(self.strict) { self.strict_validator.validate(py, input, state) } else { diff --git a/src/validators/list.rs b/src/validators/list.rs index b2e0ff116..1a8ecb134 100644 --- a/src/validators/list.rs +++ b/src/validators/list.rs @@ -121,7 +121,7 @@ impl Validator for ListValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> 
ValResult { let seq = input.validate_list(state.strict_or(self.strict))?; let exactness = match &seq { GenericIterable::List(_) | GenericIterable::JsonArray(_) => Exactness::Exact, diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 686920cca..0f1caf601 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -95,7 +95,7 @@ impl LiteralLookup { &self, py: Python<'data>, input: &'data I, - ) -> ValResult<'data, Option<(&'data I, &T)>> { + ) -> ValResult> { if let Some(expected_bool) = &self.expected_bool { if let Ok(bool_value) = input.validate_bool(true) { if bool_value.into_inner() { @@ -195,7 +195,7 @@ impl Validator for LiteralValidator { py: Python<'data>, input: &'data impl Input<'data>, _state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { match self.lookup.validate(py, input)? { Some((_, v)) => Ok(v.clone()), None => Err(ValError::new( diff --git a/src/validators/mod.rs b/src/validators/mod.rs index f541ea45d..ee95fa799 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -330,7 +330,7 @@ impl SchemaValidator { context: Option<&'data PyAny>, self_instance: Option<&PyAny>, recursion_guard: &'data mut RecursionGuard, - ) -> ValResult<'data, PyObject> + ) -> ValResult where 's: 'data, { @@ -701,15 +701,15 @@ pub trait Validator: Send + Sync + Debug { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject>; + ) -> ValResult; /// Get a default value, currently only used by `WithDefaultValidator` - fn default_value<'data>( + fn default_value( &self, - _py: Python<'data>, + _py: Python<'_>, _outer_loc: Option>, _state: &mut ValidationState, - ) -> ValResult<'data, Option> { + ) -> ValResult> { Ok(None) } @@ -722,7 +722,7 @@ pub trait Validator: Send + Sync + Debug { _field_name: &'data str, _field_value: &'data PyAny, _state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let py_err = 
PyTypeError::new_err(format!("validate_assignment is not supported for {}", self.get_name())); Err(py_err.into()) } diff --git a/src/validators/model.rs b/src/validators/model.rs index a571ccd9e..6f810d10e 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -109,7 +109,7 @@ impl Validator for ModelValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if let Some(self_instance) = state.extra().self_instance { // in the case that self_instance is Some, we're calling validation from within `BaseModel.__init__` return self.validate_init(py, self_instance, input, state); @@ -157,7 +157,7 @@ impl Validator for ModelValidator { field_name: &'data str, field_value: &'data PyAny, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if self.frozen { return Err(ValError::new(ErrorTypeDefaults::FrozenInstance, field_value)); } else if self.root_model { @@ -222,7 +222,7 @@ impl ModelValidator { self_instance: &'s PyAny, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { // we need to set `self_instance` to None for nested validators as we don't want to operate on self_instance // anymore let state = &mut state.rebind_extra(|extra| extra.self_instance = None); @@ -249,7 +249,7 @@ impl ModelValidator { input: &'data impl Input<'data>, existing_fields_set: Option<&'data PyAny>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if self.custom_init { // If we wanted, we could introspect the __init__ signature, and store the // keyword arguments and types, and create a validator for them. 
@@ -291,7 +291,7 @@ impl ModelValidator { instance: PyObject, input: &'data impl Input<'data>, extra: &Extra, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if let Some(ref post_init) = self.post_init { instance .call_method1(py, post_init.as_ref(py), (extra.context,)) diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index a284bd4e9..f603c7b22 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -123,7 +123,7 @@ impl Validator for ModelFieldsValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let strict = state.strict_or(self.strict); let from_attributes = state.extra().from_attributes.unwrap_or(self.from_attributes); @@ -205,11 +205,11 @@ impl Validator for ModelFieldsValidator { for err in line_errors { errors.push( lookup_path.apply_error_loc(err, self.loc_by_alias, &field.name) - .into_owned(py) + ); } } - Err(err) => return ControlFlow::Break(err.into_owned(py)), + Err(err) => return ControlFlow::Break(err), } continue; } @@ -258,14 +258,14 @@ impl Validator for ModelFieldsValidator { errors.push( err.with_outer_location(raw_key.as_loc_item()) .with_type(ErrorTypeDefaults::InvalidKey) - .into_owned(py) + ); } continue; } - Err(err) => return Err(err.into_owned(py)), + Err(err) => return Err(err), }; - let cow = either_str.as_cow().map_err(|err| err.into_owned(py))?; + let cow = either_str.as_cow().map_err(|err| err)?; if used_keys.contains(cow.as_ref()) { continue; } @@ -280,7 +280,7 @@ impl Validator for ModelFieldsValidator { value, raw_key.as_loc_item(), ) - .into_owned(py) + ); } ExtraBehavior::Ignore => {} @@ -294,10 +294,10 @@ impl Validator for ModelFieldsValidator { } Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(raw_key.as_loc_item()).into_owned(py)); + errors.push(err.with_outer_location(raw_key.as_loc_item())); } } - Err(err) => 
return Err(err.into_owned(py)), + Err(err) => return Err(err), } } else { model_extra_dict.set_item(py_key, value.to_object(py))?; @@ -342,7 +342,7 @@ impl Validator for ModelFieldsValidator { field_name: &'data str, field_value: &'data PyAny, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let dict: &PyDict = obj.downcast()?; let get_updated_dict = |output: PyObject| { @@ -350,7 +350,7 @@ impl Validator for ModelFieldsValidator { Ok(dict) }; - let prepare_result = |result: ValResult<'data, PyObject>| match result { + let prepare_result = |result: ValResult| match result { Ok(output) => get_updated_dict(output), Err(ValError::LineErrors(line_errors)) => { let errors = line_errors diff --git a/src/validators/none.rs b/src/validators/none.rs index f6891292b..c014c1218 100644 --- a/src/validators/none.rs +++ b/src/validators/none.rs @@ -29,7 +29,7 @@ impl Validator for NoneValidator { py: Python<'data>, input: &'data impl Input<'data>, _state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { match input.is_none() { true => Ok(py.None()), false => Err(ValError::new(ErrorTypeDefaults::NoneRequired, input)), diff --git a/src/validators/nullable.rs b/src/validators/nullable.rs index 85fbd6c26..7e5c68d03 100644 --- a/src/validators/nullable.rs +++ b/src/validators/nullable.rs @@ -38,7 +38,7 @@ impl Validator for NullableValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { match input.is_none() { true => Ok(py.None()), false => self.validator.validate(py, input, state), diff --git a/src/validators/set.rs b/src/validators/set.rs index d29c60c3f..38f21a026 100644 --- a/src/validators/set.rs +++ b/src/validators/set.rs @@ -63,7 +63,7 @@ impl Validator for SetValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let collection = 
input.validate_set(state.strict_or(self.strict))?; let exactness = match &collection { GenericIterable::Set(_) => Exactness::Exact, diff --git a/src/validators/string.rs b/src/validators/string.rs index 98d8a9d99..a8d66a42e 100644 --- a/src/validators/string.rs +++ b/src/validators/string.rs @@ -46,7 +46,7 @@ impl Validator for StrValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { input .validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str) .map(|val_match| val_match.unpack(state).into_py(py)) @@ -78,7 +78,7 @@ impl Validator for StrConstrainedValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let either_str = input .validate_str(state.strict_or(self.strict), self.coerce_numbers_to_str)? .unpack(state); diff --git a/src/validators/time.rs b/src/validators/time.rs index abf82091f..687e01eaf 100644 --- a/src/validators/time.rs +++ b/src/validators/time.rs @@ -45,7 +45,7 @@ impl Validator for TimeValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let time = input .validate_time(state.strict_or(self.strict), self.microseconds_precision)? .unpack(state); diff --git a/src/validators/timedelta.rs b/src/validators/timedelta.rs index f04fef91c..326a279e8 100644 --- a/src/validators/timedelta.rs +++ b/src/validators/timedelta.rs @@ -70,7 +70,7 @@ impl Validator for TimeDeltaValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let timedelta = input .validate_timedelta(state.strict_or(self.strict), self.microseconds_precision)? 
.unpack(state); diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index 9513582e5..af134c8ee 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -50,7 +50,7 @@ impl Validator for TupleVariableValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let seq = input.validate_tuple(state.strict_or(self.strict))?; let exactness = match &seq { GenericIterable::Tuple(_) | GenericIterable::JsonArray(_) => Exactness::Exact, @@ -118,12 +118,12 @@ fn validate_tuple_positional<'s, 'data, T: Iterator>, input: &'data impl Input<'data>, state: &mut ValidationState, output: &mut Vec, - errors: &mut Vec>, + errors: &mut Vec, extras_validator: &Option>, items_validators: &[CombinedValidator], collection_iter: &mut T, actual_length: Option, -) -> ValResult<'data, ()> { +) -> ValResult<()> { for (index, validator) in items_validators.iter().enumerate() { match collection_iter.next() { Some(result) => match validator.validate(py, result?, state) { @@ -186,7 +186,7 @@ impl Validator for TuplePositionalValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let collection = input.validate_tuple(state.strict_or(self.strict))?; let exactness: crate::validators::Exactness = match &collection { GenericIterable::Tuple(_) | GenericIterable::JsonArray(_) => Exactness::Exact, diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index f55b7d717..95cd19133 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -148,7 +148,7 @@ impl Validator for TypedDictValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let strict = state.strict_or(self.strict); let dict = input.validate_dict(strict)?; @@ -205,11 +205,10 @@ impl Validator for TypedDictValidator { 
errors.push( lookup_path .apply_error_loc(err, self.loc_by_alias, &field.name) - .into_owned(py) ); } } - Err(err) => return ControlFlow::Break(err.into_owned(py)), + Err(err) => return ControlFlow::Break(err), } continue; } @@ -259,14 +258,14 @@ impl Validator for TypedDictValidator { errors.push( err.with_outer_location(raw_key.as_loc_item()) .with_type(ErrorTypeDefaults::InvalidKey) - .into_owned(py) + ); } continue; } - Err(err) => return Err(err.into_owned(py)), + Err(err) => return Err(err), }; - let cow = either_str.as_cow().map_err(|err| err.into_owned(py))?; + let cow = either_str.as_cow().map_err(|err| err)?; if used_keys.contains(cow.as_ref()) { continue; } @@ -281,7 +280,7 @@ impl Validator for TypedDictValidator { value, raw_key.as_loc_item(), ) - .into_owned(py) + ); } ExtraBehavior::Ignore => {} @@ -297,11 +296,11 @@ impl Validator for TypedDictValidator { errors.push( err .with_outer_location(raw_key.as_loc_item()) - .into_owned(py) + ); } } - Err(err) => return Err(err.into_owned(py)), + Err(err) => return Err(err), } } else { output_dict.set_item(py_key, value.to_object(py))?; diff --git a/src/validators/union.rs b/src/validators/union.rs index 0f8fded07..0148e1287 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -106,7 +106,7 @@ impl UnionValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let old_exactness = state.exactness; let strict = state.strict_or(self.strict); let mut errors = MaybeErrors::new(self.custom_error.as_ref()); @@ -172,7 +172,7 @@ impl UnionValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let mut errors = MaybeErrors::new(self.custom_error.as_ref()); let mut rebound_state; @@ -207,7 +207,7 @@ impl Validator for UnionValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> 
ValResult<'data, PyObject> { + ) -> ValResult { match self.mode { UnionMode::Smart => self.validate_smart(py, input, state), UnionMode::LeftToRight => self.validate_left_to_right(py, input, state), @@ -219,18 +219,18 @@ impl Validator for UnionValidator { } } -struct ChoiceLineErrors<'a, 'data> { +struct ChoiceLineErrors<'a> { choice: &'a CombinedValidator, label: Option<&'a str>, - line_errors: Vec>, + line_errors: Vec, } -enum MaybeErrors<'a, 'data> { +enum MaybeErrors<'a> { Custom(&'a CustomError), - Errors(SmallVec<[ChoiceLineErrors<'a, 'data>; 4]>), + Errors(SmallVec<[ChoiceLineErrors<'a>; 4]>), } -impl<'a, 'data> MaybeErrors<'a, 'data> { +impl<'a> MaybeErrors<'a> { fn new(custom_error: Option<&'a CustomError>) -> Self { match custom_error { Some(custom_error) => Self::Custom(custom_error), @@ -238,7 +238,7 @@ impl<'a, 'data> MaybeErrors<'a, 'data> { } } - fn push(&mut self, choice: &'a CombinedValidator, label: Option<&'a str>, line_errors: Vec>) { + fn push(&mut self, choice: &'a CombinedValidator, label: Option<&'a str>, line_errors: Vec) { match self { Self::Custom(_) => {} Self::Errors(errors) => errors.push(ChoiceLineErrors { @@ -249,7 +249,7 @@ impl<'a, 'data> MaybeErrors<'a, 'data> { } } - fn into_val_error(self, input: &'data impl Input<'data>) -> ValError<'data> { + fn into_val_error<'i>(self, input: &impl Input<'i>) -> ValError { match self { Self::Custom(custom_error) => custom_error.as_val_error(input), Self::Errors(errors) => ValError::LineErrors( @@ -395,7 +395,7 @@ impl Validator for TaggedUnionValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { match self.discriminator { Discriminator::LookupKey(ref lookup_key) => { macro_rules! 
find_validator { @@ -445,7 +445,7 @@ impl TaggedUnionValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - ) -> ValResult<'data, &'data PyString> { + ) -> ValResult<&'data PyString> { let dict = input.strict_dict()?; let either_tag = match dict { GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "type"))? { @@ -492,7 +492,7 @@ impl TaggedUnionValidator { tag: &'data PyAny, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) { return match validator.validate(py, input, state) { Ok(res) => Ok(res), @@ -513,7 +513,7 @@ impl TaggedUnionValidator { } } - fn tag_not_found<'s, 'data>(&'s self, input: &'data impl Input<'data>) -> ValError<'data> { + fn tag_not_found<'s, 'data>(&'s self, input: &'data impl Input<'data>) -> ValError { match self.custom_error { Some(ref custom_error) => custom_error.as_val_error(input), None => ValError::new( diff --git a/src/validators/url.rs b/src/validators/url.rs index 77f887b7d..ba0d4cd4d 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -65,7 +65,7 @@ impl Validator for UrlValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let mut lib_url = self.get_url(input, state.strict_or(self.strict))?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { @@ -93,7 +93,7 @@ impl Validator for UrlValidator { state.floor_exactness(Exactness::Lax); Ok(PyUrl::new(lib_url).into_py(py)) } - Err(error_type) => return Err(ValError::new(error_type, input)), + Err(error_type) => Err(ValError::new(error_type, input)), } } @@ -103,7 +103,7 @@ impl Validator for UrlValidator { } impl UrlValidator { - fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, Url> { + fn get_url<'s, 'data>(&'s self, input: &'data 
impl Input<'data>, strict: bool) -> ValResult { match input.validate_str(strict, false) { Ok(val_match) => { let either_str = val_match.into_inner(); @@ -133,7 +133,7 @@ impl UrlValidator { } } - fn check_length<'s, 'data>(&self, input: &'data impl Input<'data>, url_str: &str) -> ValResult<'data, ()> { + fn check_length<'s, 'data>(&self, input: &'data impl Input<'data>, url_str: &str) -> ValResult<()> { if let Some(max_length) = self.max_length { if url_str.len() > max_length { return Err(ValError::new( @@ -199,7 +199,7 @@ impl Validator for MultiHostUrlValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let mut multi_url = self.get_url(input, state.strict_or(self.strict))?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { @@ -226,7 +226,7 @@ impl Validator for MultiHostUrlValidator { state.floor_exactness(Exactness::Lax); Ok(multi_url.into_py(py)) } - Err(error_type) => return Err(ValError::new(error_type, input)), + Err(error_type) => Err(ValError::new(error_type, input)), } } @@ -236,7 +236,7 @@ impl Validator for MultiHostUrlValidator { } impl MultiHostUrlValidator { - fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult<'data, PyMultiHostUrl> { + fn get_url<'s, 'data>(&'s self, input: &'data impl Input<'data>, strict: bool) -> ValResult { match input.validate_str(strict, false) { Ok(val_match) => { let either_str = val_match.into_inner(); @@ -264,7 +264,7 @@ impl MultiHostUrlValidator { } } - fn check_length<'s, 'data, F>(&self, input: &'data impl Input<'data>, func: F) -> ValResult<'data, ()> + fn check_length<'s, 'data, F>(&self, input: &'data impl Input<'data>, func: F) -> ValResult<()> where F: FnOnce() -> usize, { @@ -287,7 +287,7 @@ fn parse_multihost_url<'url, 'input>( url_str: &'url str, input: &'input impl Input<'input>, strict: bool, -) -> ValResult<'input, PyMultiHostUrl> { +) -> 
ValResult { macro_rules! parsing_err { ($parse_error:expr) => { Err(ValError::new( @@ -402,11 +402,7 @@ fn parse_multihost_url<'url, 'input>( } } -fn parse_url<'url, 'input>( - url_str: &'url str, - input: &'input impl Input<'input>, - strict: bool, -) -> ValResult<'input, Url> { +fn parse_url<'url, 'input>(url_str: &'url str, input: &'input impl Input<'input>, strict: bool) -> ValResult { if url_str.is_empty() { return Err(ValError::new( ErrorType::UrlParsing { diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index 3324d5ea1..4cfd7a272 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -90,7 +90,7 @@ impl Validator for UuidValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { let class = get_uuid_type(py)?; if let Some(py_input) = input.input_is_instance(class) { if let Some(expected_version) = self.version { @@ -135,7 +135,7 @@ impl Validator for UuidValidator { } impl UuidValidator { - fn get_uuid<'s, 'data>(&'s self, input: &'data impl Input<'data>) -> ValResult<'data, Uuid> { + fn get_uuid<'s, 'data>(&'s self, input: &'data impl Input<'data>) -> ValResult { let uuid = match input.exact_str().ok() { Some(either_string) => { let cow = either_string.as_cow()?; @@ -198,7 +198,7 @@ impl UuidValidator { /// /// This implementation does not use the Python `__init__` function to speed up the process, /// as the `__init__` function in the Python `uuid` module performs extensive checks. 
- fn create_py_uuid<'py>(&self, py: Python<'py>, py_type: &PyType, uuid: &Uuid) -> ValResult<'py, Py> { + fn create_py_uuid(&self, py: Python<'_>, py_type: &PyType, uuid: &Uuid) -> ValResult> { let class = create_class(py_type)?; let dc = class.as_ref(py); let int = uuid.as_u128(); diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index a06ccd0cd..443c0e271 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -131,7 +131,7 @@ impl Validator for WithDefaultValidator { py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult<'data, PyObject> { + ) -> ValResult { if input.to_object(py).is(&PydanticUndefinedType::py_undefined()) { Ok(self.default_value(py, None::, state)?.unwrap()) } else { @@ -149,12 +149,12 @@ impl Validator for WithDefaultValidator { } } - fn default_value<'data>( + fn default_value( &self, - py: Python<'data>, + py: Python<'_>, outer_loc: Option>, state: &mut ValidationState, - ) -> ValResult<'data, Option> { + ) -> ValResult> { match self.default.default_value(py)? 
{ Some(stored_dft) => { let dft: Py = if self.copy_default { From 7634e5e74a26c10d5d2b5c6a39017ea576a5d94d Mon Sep 17 00:00:00 2001 From: Marius Lie Winger <89073985+mariuswinger@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:27:01 +0100 Subject: [PATCH 129/550] Fix typo in is_instance_schema docstring (#1087) --- python/pydantic_core/core_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index fec3b9966..c097d740e 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -1196,7 +1196,7 @@ def is_instance_schema( serialization: SerSchema | None = None, ) -> IsInstanceSchema: """ - Returns a schema that checks if a value is an instance of a class, equivalent to python's `isinstnace` method, e.g.: + Returns a schema that checks if a value is an instance of a class, equivalent to python's `isinstance` method, e.g.: ```py from pydantic_core import SchemaValidator, core_schema From 3d3f40610ce49586c4e865b82db1544dd1ffb061 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 21 Nov 2023 16:30:32 +0000 Subject: [PATCH 130/550] Don't build dummy objects when populating `ObTypeLookup` (#1086) --- src/serializers/ob_type.rs | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index ff43a1065..7493d5abd 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -59,29 +59,26 @@ pub enum IsType { impl ObTypeLookup { fn new(py: Python) -> Self { - let lib_url = url::Url::parse("https://example.com").unwrap(); Self { none: py.None().as_ref(py).get_type_ptr() as usize, - int: 0i32.into_py(py).as_ref(py).get_type_ptr() as usize, - bool: true.into_py(py).as_ref(py).get_type_ptr() as usize, - float: 0f32.into_py(py).as_ref(py).get_type_ptr() as usize, - list: PyList::empty(py).get_type_ptr() as usize, - dict: 
PyDict::new(py).get_type_ptr() as usize, + int: PyInt::type_object_raw(py) as usize, + bool: PyBool::type_object_raw(py) as usize, + float: PyFloat::type_object_raw(py) as usize, + list: PyList::type_object_raw(py) as usize, + dict: PyDict::type_object_raw(py) as usize, decimal_object: py.import("decimal").unwrap().getattr("Decimal").unwrap().to_object(py), - string: PyString::new(py, "s").get_type_ptr() as usize, - bytes: PyBytes::new(py, b"s").get_type_ptr() as usize, - bytearray: PyByteArray::new(py, b"s").get_type_ptr() as usize, - tuple: PyTuple::empty(py).get_type_ptr() as usize, - set: PySet::empty(py).unwrap().get_type_ptr() as usize, - frozenset: PyFrozenSet::empty(py).unwrap().get_type_ptr() as usize, - datetime: PyDateTime::new(py, 2000, 1, 1, 0, 0, 0, 0, None) - .unwrap() - .get_type_ptr() as usize, - date: PyDate::new(py, 2000, 1, 1).unwrap().get_type_ptr() as usize, - time: PyTime::new(py, 0, 0, 0, 0, None).unwrap().get_type_ptr() as usize, - timedelta: PyDelta::new(py, 0, 0, 0, false).unwrap().get_type_ptr() as usize, - url: PyUrl::new(lib_url.clone()).into_py(py).as_ref(py).get_type_ptr() as usize, - multi_host_url: PyMultiHostUrl::new(lib_url, None).into_py(py).as_ref(py).get_type_ptr() as usize, + string: PyString::type_object_raw(py) as usize, + bytes: PyBytes::type_object_raw(py) as usize, + bytearray: PyByteArray::type_object_raw(py) as usize, + tuple: PyTuple::type_object_raw(py) as usize, + set: PySet::type_object_raw(py) as usize, + frozenset: PyFrozenSet::type_object_raw(py) as usize, + datetime: PyDateTime::type_object_raw(py) as usize, + date: PyDate::type_object_raw(py) as usize, + time: PyTime::type_object_raw(py) as usize, + timedelta: PyDelta::type_object_raw(py) as usize, + url: PyUrl::type_object_raw(py) as usize, + multi_host_url: PyMultiHostUrl::type_object_raw(py) as usize, enum_object: py.import("enum").unwrap().getattr("Enum").unwrap().to_object(py), generator_object: py .import("types") From 
a7739ec8681bd5fe5d31be37b11cef482a05ccfb Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 21 Nov 2023 16:53:43 +0000 Subject: [PATCH 131/550] Remove needless uses of `PyString::intern` (#1088) --- src/lookup_key.rs | 20 +++++++------------ src/serializers/type_serializers/dataclass.rs | 2 +- src/validators/arguments.rs | 2 +- src/validators/dataclass.rs | 2 +- src/validators/model.rs | 2 +- src/validators/model_fields.rs | 2 +- src/validators/typed_dict.rs | 2 +- 7 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/lookup_key.rs b/src/lookup_key.rs index 21b584690..e145c1f41 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -51,12 +51,6 @@ impl fmt::Display for LookupKey { } } -macro_rules! py_string { - ($py:ident, $str:expr) => { - PyString::intern($py, $str).into() - }; -} - impl LookupKey { pub fn from_py(py: Python, value: &PyAny, alt_alias: Option<&str>) -> PyResult { if let Ok(alias_py) = value.downcast::() { @@ -67,7 +61,7 @@ impl LookupKey { py_key1: alias_py.into_py(py), path1: LookupPath::from_str(py, alias, Some(alias_py)), key2: alt_alias.to_string(), - py_key2: py_string!(py, alt_alias), + py_key2: PyString::new(py, alt_alias).into(), path2: LookupPath::from_str(py, alt_alias, None), }), None => Ok(Self::simple(py, alias, Some(alias_py))), @@ -98,12 +92,12 @@ impl LookupKey { fn simple(py: Python, key: &str, opt_py_key: Option<&PyString>) -> Self { let py_key = match opt_py_key { - Some(py_key) => py_key.into_py(py), - None => py_string!(py, key), + Some(py_key) => py_key, + None => PyString::new(py, key), }; Self::Simple { key: key.to_string(), - py_key, + py_key: py_key.into(), path: LookupPath::from_str(py, key, opt_py_key), } } @@ -348,10 +342,10 @@ impl fmt::Display for LookupPath { impl LookupPath { fn from_str(py: Python, key: &str, py_key: Option<&PyString>) -> Self { let py_key = match py_key { - Some(py_key) => py_key.into_py(py), - None => py_string!(py, key), + Some(py_key) => py_key, + None => 
PyString::new(py, key), }; - Self(vec![PathItem::S(key.to_string(), py_key)]) + Self(vec![PathItem::S(key.to_string(), py_key.into())]) } fn from_list(obj: &PyAny) -> PyResult { diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index 787e267dd..a82643186 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -39,7 +39,7 @@ impl BuildSerializer for DataclassArgsBuilder { let field_info: &PyDict = item.downcast()?; let name: String = field_info.get_as_req(intern!(py, "name"))?; - let key_py: Py = PyString::intern(py, &name).into_py(py); + let key_py: Py = PyString::new(py, &name).into_py(py); if field_info.get_as(intern!(py, "serialization_exclude"))? == Some(true) { fields.insert(name, SerField::new(py, key_py, None, None, true)); diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index 0405d54e5..aaf5afb6b 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -73,7 +73,7 @@ impl BuildValidator for ArgumentsValidator { } None => Some(LookupKey::from_string(py, &name)), }; - kwarg_key = Some(PyString::intern(py, &name).into()); + kwarg_key = Some(PyString::new(py, &name).into()); } let schema: &PyAny = arg.get_as_req(intern!(py, "schema"))?; diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index dbbfc5ee2..986556b66 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -477,7 +477,7 @@ impl BuildValidator for DataclassValidator { let validator = build_validator(sub_schema, config, definitions)?; let post_init = if schema.get_as::(intern!(py, "post_init"))?.unwrap_or(false) { - Some(PyString::intern(py, "__post_init__").into_py(py)) + Some(PyString::new(py, "__post_init__").into_py(py)) } else { None }; diff --git a/src/validators/model.rs b/src/validators/model.rs index 6f810d10e..1831d2fde 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -89,7 
+89,7 @@ impl BuildValidator for ModelValidator { class: class.into(), post_init: schema .get_as::<&str>(intern!(py, "post_init"))? - .map(|s| PyString::intern(py, s).into_py(py)), + .map(|s| PyString::new(py, s).into_py(py)), frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false), root_model: schema.get_as(intern!(py, "root_model"))?.unwrap_or(false), diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index f603c7b22..c73a2821d 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -93,7 +93,7 @@ impl BuildValidator for ModelFieldsValidator { fields.push(Field { name: field_name.to_string(), lookup_key, - name_py: PyString::intern(py, field_name).into(), + name_py: PyString::new(py, field_name).into(), validator, frozen: field_info.get_as::(intern!(py, "frozen"))?.unwrap_or(false), }); diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 95cd19133..d2fc00ec7 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -120,7 +120,7 @@ impl BuildValidator for TypedDictValidator { fields.push(TypedDictField { name: field_name.to_string(), lookup_key, - name_py: PyString::intern(py, field_name).into(), + name_py: PyString::new(py, field_name).into(), validator, required, }); From 3840e006bfa7be234f4d1ce20becacbf57231d83 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Tue, 21 Nov 2023 22:23:20 +0000 Subject: [PATCH 132/550] Correct deprecation message for general_after_validator_function (#1090) --- python/pydantic_core/core_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index c097d740e..daed22d48 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3862,7 +3862,7 @@ def field_after_validator_function(function: WithInfoValidatorFunction, 
field_na @deprecated('`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.') def general_after_validator_function(*args, **kwargs): warnings.warn( - '`with_info_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.', + '`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.', DeprecationWarning, ) return with_info_after_validator_function(*args, **kwargs) From 5b63e7a89df6d48ca0b48b4a885c96f7c2cfd413 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 22 Nov 2023 03:29:18 -0600 Subject: [PATCH 133/550] Avoid using `?` with `get_item` to handle unhashable inputs properly (#1089) Co-authored-by: Samuel Colvin --- src/validators/literal.rs | 5 ++++- tests/validators/test_union.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 0f1caf601..9f8c25e0f 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -136,7 +136,10 @@ impl LiteralLookup { } // must be an enum or bytes if let Some(expected_py) = &self.expected_py { - if let Some(v) = expected_py.as_ref(py).get_item(input)? { + // We don't use ? 
to unpack the result of `get_item` in the next line because unhashable + // inputs will produce a TypeError, which in this case we just want to treat equivalently + // to a failed lookup + if let Ok(Some(v)) = expected_py.as_ref(py).get_item(input) { let id: usize = v.extract().unwrap(); return Ok(Some((input, &self.values[id]))); } diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 503a5f387..42c542000 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -771,3 +771,35 @@ class BinaryEnum(IntEnum): assert validator.validate_python(1) is not BinaryEnum.ONE assert validator.validate_python(BinaryEnum.ZERO) is BinaryEnum.ZERO assert validator.validate_python(BinaryEnum.ONE) is BinaryEnum.ONE + + +def test_model_and_literal_union() -> None: + # see https://github.com/pydantic/pydantic/issues/8183 + class ModelA: + pass + + validator = SchemaValidator( + { + 'type': 'union', + 'choices': [ + { + 'type': 'model', + 'cls': ModelA, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'a': {'type': 'model-field', 'schema': {'type': 'int'}}, + }, + }, + }, + {'type': 'literal', 'expected': [True]}, + ], + } + ) + + # validation against Literal[True] fails bc of the unhashable dict + # A ValidationError is raised, not a ValueError, which allows the validation against the union to continue + m = validator.validate_python({'a': 42}) + assert isinstance(m, ModelA) + assert m.a == 42 + assert validator.validate_python(True) is True From c7daf167140460244c9287c557cd20d1d9ca89d8 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 27 Nov 2023 11:16:13 +0000 Subject: [PATCH 134/550] support newest jiter behaviour (#1092) Co-authored-by: David Hewitt --- Cargo.lock | 4 +- Cargo.toml | 3 +- python/pydantic_core/_pydantic_core.pyi | 4 +- src/input/input_abstract.rs | 4 -- src/input/input_json.rs | 13 +----- src/input/input_python.rs | 20 +--------- src/input/input_string.rs | 13 +----- src/input/shared.rs | 13 +----- 
src/lib.rs | 22 +++++----- src/validators/json.rs | 53 +++++++++++++++++++++---- src/validators/mod.rs | 51 ++++++++++++++---------- tests/validators/test_json.py | 18 +++++---- 12 files changed, 105 insertions(+), 113 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 249fdfbf9..2ba82639c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,9 +138,9 @@ checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" [[package]] name = "jiter" -version = "0.0.4" +version = "0.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b27d419c535bf7b50ad355278b1159cbf0cc8d507ea003d625b17bf0375720b8" +checksum = "e184598fea113663dd78e33a24ad3a1e7ba8ceedf71effb7406b3f2eccb63ed1" dependencies = [ "ahash", "lexical-core", diff --git a/Cargo.toml b/Cargo.toml index 512d98908..20f4b79b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,8 +43,7 @@ base64 = "0.21.5" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.5.0" -jiter = {version = "0.0.4", features = ["python"]} -#jiter = {path = "../jiter", features = ["python"]} +jiter = {version = "0.0.5", features = ["python"]} [lib] name = "_pydantic_core" diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 4c99a4d61..382a6c804 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -385,7 +385,7 @@ def to_json( JSON bytes. """ -def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> Any: +def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True, cache_strings: bool = True) -> Any: """ Deserialize JSON data to a Python object. @@ -394,6 +394,8 @@ def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> A Arguments: data: The JSON data to deserialize. allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default. 
+ cache_strings: Whether to cache strings to avoid constructing new Python objects, + this should have a significant impact on performance while increasing memory usage slightly. Raises: ValueError: If deserialization fails. diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index ceb0495a9..8229677b6 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -4,8 +4,6 @@ use pyo3::exceptions::PyValueError; use pyo3::types::{PyDict, PyType}; use pyo3::{intern, prelude::*}; -use jiter::JsonValue; - use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult}; use crate::tools::py_err; use crate::{PyMultiHostUrl, PyUrl}; @@ -89,8 +87,6 @@ pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult>; - fn parse_json(&'a self) -> ValResult; - fn validate_str( &'a self, strict: bool, diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 3e94adad6..195f9caba 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -14,7 +14,7 @@ use super::datetime::{ float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime, }; use super::return_enums::ValidationMatch; -use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int}; +use super::shared::{float_as_int, int_as_bool, str_as_bool, str_as_float, str_as_int}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Input, JsonArgs, @@ -84,13 +84,6 @@ impl<'a> Input<'a> for JsonValue { } } - fn parse_json(&'a self) -> ValResult { - match self { - JsonValue::Str(s) => JsonValue::parse(s.as_bytes(), true).map_err(|e| map_json_err(self, e)), - _ => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), - } - } - fn exact_str(&'a self) -> ValResult> { match self { JsonValue::Str(s) => 
Ok(s.as_str().into()), @@ -367,10 +360,6 @@ impl<'a> Input<'a> for String { )) } - fn parse_json(&'a self) -> ValResult { - JsonValue::parse(self.as_bytes(), true).map_err(|e| map_json_err(self, e)) - } - fn validate_str( &'a self, _strict: bool, diff --git a/src/input/input_python.rs b/src/input/input_python.rs index ba42abbb8..5d4d4826c 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -10,7 +10,6 @@ use pyo3::types::{ use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues}; use pyo3::{intern, PyTypeInfo}; -use jiter::JsonValue; use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; @@ -26,8 +25,7 @@ use super::datetime::{ }; use super::return_enums::ValidationMatch; use super::shared::{ - decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float, - str_as_int, + decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, str_as_bool, str_as_float, str_as_int, }; use super::{ py_string_str, BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, @@ -195,22 +193,6 @@ impl<'a> Input<'a> for PyAny { } } - fn parse_json(&'a self) -> ValResult { - let bytes = if let Ok(py_bytes) = self.downcast::() { - py_bytes.as_bytes() - } else if let Ok(py_str) = self.downcast::() { - let str = py_string_str(py_str)?; - str.as_bytes() - } else if let Ok(py_byte_array) = self.downcast::() { - // Safety: from_slice does not run arbitrary Python code and the GIL is held so the - // bytes array will not be mutated while `JsonValue::parse` is reading it - unsafe { py_byte_array.as_bytes() } - } else { - return Err(ValError::new(ErrorTypeDefaults::JsonType, self)); - }; - JsonValue::parse(bytes, true).map_err(|e| map_json_err(self, e)) - } - fn validate_str( &'a self, strict: bool, diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 
290247617..0c2c9a8ca 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -1,7 +1,6 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; -use jiter::JsonValue; use speedate::MicrosecondsPrecisionOverflowBehavior; use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; @@ -12,7 +11,7 @@ use crate::validators::decimal::create_decimal; use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{map_json_err, str_as_bool, str_as_float}; +use super::shared::{str_as_bool, str_as_float}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Input, ValidationMatch, @@ -86,16 +85,6 @@ impl<'a> Input<'a> for StringMapping<'a> { } } - fn parse_json(&'a self) -> ValResult { - match self { - Self::String(s) => { - let str = py_string_str(s)?; - JsonValue::parse(str.as_bytes(), true).map_err(|e| map_json_err(self, e)) - } - Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::JsonType, self)), - } - } - fn validate_str( &'a self, _strict: bool, diff --git a/src/input/shared.rs b/src/input/shared.rs index 647bce8a3..591c5abfc 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -1,10 +1,9 @@ use pyo3::sync::GILOnceCell; use pyo3::{intern, Py, PyAny, Python, ToPyObject}; -use jiter::JsonValueError; use num_bigint::BigInt; -use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; +use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use super::{EitherFloat, EitherInt, Input}; static ENUM_META_OBJECT: GILOnceCell> = GILOnceCell::new(); @@ -20,16 +19,6 @@ pub fn get_enum_meta_object(py: Python) -> Py { .clone() } -pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError { - ValError::new( - ErrorType::JsonInvalid { - error: error.to_string(), - context: 
None, - }, - input, - ) -} - pub fn str_as_bool<'a>(input: &'a impl Input<'a>, str: &str) -> ValResult { if str == "0" || str.eq_ignore_ascii_case("f") diff --git a/src/lib.rs b/src/lib.rs index f969c0657..de4a6d9bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,6 @@ extern crate core; use std::sync::OnceLock; use pyo3::exceptions::PyTypeError; -use pyo3::types::{PyByteArray, PyBytes, PyString}; use pyo3::{prelude::*, sync::GILOnceCell}; // parse this first to get access to the contained macro @@ -37,17 +36,16 @@ pub use serializers::{ }; pub use validators::{validate_core_schema, PySome, SchemaValidator}; -#[pyfunction(signature = (data, *, allow_inf_nan=true))] -pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool) -> PyResult { - if let Ok(py_bytes) = data.downcast::() { - jiter::python_parse(py, py_bytes.as_bytes(), allow_inf_nan) - } else if let Ok(py_str) = data.downcast::() { - jiter::python_parse(py, py_str.to_str()?.as_bytes(), allow_inf_nan) - } else if let Ok(py_byte_array) = data.downcast::() { - jiter::python_parse(py, &py_byte_array.to_vec(), allow_inf_nan) - } else { - Err(PyTypeError::new_err("Expected bytes, bytearray or str")) - } +use crate::input::Input; + +#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=true))] +pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool, cache_strings: bool) -> PyResult { + let v_match = data + .validate_bytes(false) + .map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?; + let json_either_bytes = v_match.into_inner(); + let json_bytes = json_either_bytes.as_slice(); + jiter::python_parse(py, json_bytes, allow_inf_nan, cache_strings).map_err(|e| jiter::map_json_error(json_bytes, &e)) } pub fn get_pydantic_core_version() -> &'static str { diff --git a/src/validators/json.rs b/src/validators/json.rs index 7ce1f9e29..44250ef9d 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -2,8 +2,10 @@ use pyo3::intern; use pyo3::prelude::*; use 
pyo3::types::PyDict; -use crate::errors::ValResult; -use crate::input::Input; +use jiter::JsonValue; + +use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::input::{EitherBytes, Input, ValidationMatch}; use crate::tools::SchemaDict; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -50,13 +52,19 @@ impl Validator for JsonValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult { - let json_value = input.parse_json()?; + let v_match = validate_json_bytes(input)?; + let json_either_bytes = v_match.unpack(state); + let json_bytes = json_either_bytes.as_slice(); match self.validator { - Some(ref validator) => match validator.validate(py, &json_value, state) { - Ok(v) => Ok(v), - Err(err) => Err(err), - }, - None => Ok(json_value.to_object(py)), + Some(ref validator) => { + let json_value = JsonValue::parse(json_bytes, true).map_err(|e| map_json_err(input, e, json_bytes))?; + validator.validate(py, &json_value, state) + } + None => { + let obj = + jiter::python_parse(py, json_bytes, true, true).map_err(|e| map_json_err(input, e, json_bytes))?; + Ok(obj) + } } } @@ -64,3 +72,32 @@ impl Validator for JsonValidator { &self.name } } + +pub fn validate_json_bytes<'data>(input: &'data impl Input<'data>) -> ValResult>> { + match input.validate_bytes(false) { + Ok(v_match) => Ok(v_match), + Err(ValError::LineErrors(e)) => Err(ValError::LineErrors( + e.into_iter().map(map_bytes_error).collect::>(), + )), + Err(e) => Err(e), + } +} + +fn map_bytes_error(line_error: ValLineError) -> ValLineError { + match line_error.error_type { + ErrorType::BytesType { .. 
} => { + ValLineError::new_custom_input(ErrorTypeDefaults::JsonType, line_error.input_value) + } + _ => line_error, + } +} + +pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: jiter::JsonError, json_bytes: &[u8]) -> ValError { + ValError::new( + ErrorType::JsonInvalid { + error: error.description(json_bytes), + context: None, + }, + input, + ) +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index ee95fa799..7809b1ee3 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -171,7 +171,6 @@ impl SchemaValidator { from_attributes, context, self_instance, - &mut RecursionGuard::default(), ) .map_err(|e| self.prepare_validation_err(py, e, InputType::Python)) } @@ -194,7 +193,6 @@ impl SchemaValidator { from_attributes, context, self_instance, - &mut RecursionGuard::default(), ) { Ok(_) => Ok(true), Err(ValError::InternalErr(err)) => Err(err), @@ -213,22 +211,18 @@ impl SchemaValidator { context: Option<&PyAny>, self_instance: Option<&PyAny>, ) -> PyResult { - let recursion_guard = &mut RecursionGuard::default(); - match input.parse_json() { - Ok(input) => self - ._validate( - py, - &input, - InputType::Json, - strict, - None, - context, - self_instance, - recursion_guard, - ) - .map_err(|e| self.prepare_validation_err(py, e, InputType::Json)), - Err(err) => Err(self.prepare_validation_err(py, err, InputType::Json)), - } + let r = match json::validate_json_bytes(input) { + Ok(v_match) => self._validate_json( + py, + input, + v_match.into_inner().as_slice(), + strict, + context, + self_instance, + ), + Err(err) => Err(err), + }; + r.map_err(|e| self.prepare_validation_err(py, e, InputType::Json)) } #[pyo3(signature = (input, *, strict=None, context=None))] @@ -242,8 +236,7 @@ impl SchemaValidator { let t = InputType::String; let string_mapping = StringMapping::new_value(input).map_err(|e| self.prepare_validation_err(py, e, t))?; - let recursion_guard = &mut RecursionGuard::default(); - match self._validate(py, &string_mapping, t, strict, 
None, context, None, recursion_guard) { + match self._validate(py, &string_mapping, t, strict, None, context, None) { Ok(r) => Ok(r), Err(e) => Err(self.prepare_validation_err(py, e, t)), } @@ -329,18 +322,32 @@ impl SchemaValidator { from_attributes: Option, context: Option<&'data PyAny>, self_instance: Option<&PyAny>, - recursion_guard: &'data mut RecursionGuard, ) -> ValResult where 's: 'data, { + let mut recursion_guard = RecursionGuard::default(); let mut state = ValidationState::new( Extra::new(strict, from_attributes, context, self_instance, input_type), - recursion_guard, + &mut recursion_guard, ); self.validator.validate(py, input, &mut state) } + fn _validate_json( + &self, + py: Python, + input: &PyAny, + json_data: &[u8], + strict: Option, + context: Option<&PyAny>, + self_instance: Option<&PyAny>, + ) -> ValResult { + let json_value = + jiter::JsonValue::parse(json_data, true).map_err(|e| json::map_json_err(input, e, json_data))?; + self._validate(py, &json_value, InputType::Json, strict, None, context, self_instance) + } + fn prepare_validation_err(&self, py: Python, error: ValError, input_type: InputType) -> PyErr { ValidationError::from_val_error( py, diff --git a/tests/validators/test_json.py b/tests/validators/test_json.py index d8666d335..228d18e55 100644 --- a/tests/validators/test_json.py +++ b/tests/validators/test_json.py @@ -48,36 +48,40 @@ def test_any(py_and_json: PyAndJson, input_value, expected): @pytest.mark.parametrize( 'input_value,expected', [ - ('{"a": 1}', {'a': 1}), - (b'{"a": 1}', {'a': 1}), - ( + pytest.param('{"a": 1}', {'a': 1}, id='str'), + pytest.param(b'{"a": 1}', {'a': 1}, id='bytes'), + pytest.param( '🐈 Hello \ud800World', Err( 'Input should be a valid string, unable to parse raw data as a unicode string ' "[type=string_unicode, input_value='🐈 Hello \\ud800World', input_type=str]" ), + id='str_unicode', ), - (bytearray(b'{"a": 1}'), {'a': 1}), - ( + pytest.param(bytearray(b'{"a": 1}'), {'a': 1}, id='bytearray'), + 
pytest.param( 'xx', Err( 'Invalid JSON: expected value at line 1 column 1 ' "[type=json_invalid, input_value='xx', input_type=str]" ), + id='str_invalid', ), - ( + pytest.param( b'xx', Err( 'Invalid JSON: expected value at line 1 column 1 ' "[type=json_invalid, input_value=b'xx', input_type=bytes]" ), + id='bytes_invalid', ), - ( + pytest.param( bytearray(b'xx'), Err( 'Invalid JSON: expected value at line 1 column 1 ' "[type=json_invalid, input_value=bytearray(b'xx'), input_type=bytearray]" ), + id='bytearray_invalid', ), ], ) From 3b6344e9a956f3a4df3c5766ea21995a6cd1432b Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 29 Nov 2023 11:16:46 -0600 Subject: [PATCH 135/550] Fixing `exclude_none` for json serialization of `computed_field`s (#1098) --- src/serializers/computed_fields.rs | 4 ++++ tests/serializers/test_model.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 8a1f041ae..6f4553069 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -73,6 +73,10 @@ impl ComputedFields { } for computed_field in &self.0 { let property_name_py = computed_field.property_name_py.as_ref(model.py()); + let value = model.getattr(property_name_py).map_err(py_err_se_err)?; + if extra.exclude_none && value.is_none() { + return Ok(()); + } if let Some((next_include, next_exclude)) = filter .key_filter(property_name_py, include, exclude) .map_err(py_err_se_err)? 
diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index 32ecd3c1a..4249c3015 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -608,7 +608,7 @@ def volume(self) -> int: assert s.to_json(Model(3, 4)) == b'{"width":3,"height":4,"Area":12,"volume":48}' -def test_computed_field_to_python_exclude_none(): +def test_computed_field_exclude_none(): @dataclasses.dataclass class Model: width: int @@ -646,6 +646,8 @@ def volume(self) -> None: 'volume': None, } assert s.to_python(Model(3, 4), mode='json', exclude_none=True) == {'width': 3, 'height': 4, 'Area': 12} + assert s.to_json(Model(3, 4), exclude_none=False) == b'{"width":3,"height":4,"Area":12,"volume":null}' + assert s.to_json(Model(3, 4), exclude_none=True) == b'{"width":3,"height":4,"Area":12}' @pytest.mark.skipif(cached_property is None, reason='cached_property is not available') From 5d64894e8dbec8c21dadf2d79c756dcfe4f436a3 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 29 Nov 2023 17:25:14 +0000 Subject: [PATCH 136/550] bump pandas tests to run on 3.12 (#1097) --- tests/requirements.txt | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 66dda075f..67e8e002b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,8 +1,11 @@ coverage==7.2.7 dirty-equals==0.6.0 hypothesis==6.79.4 +# TODO: remove manual override for dateutil once a version newer than 2.8.2 is +# released which removes use of deprecated utcfromtimestamp +git+https://github.com/dateutil/dateutil.git@f2293200747fb03d56c6c5997bfebeabe703576f # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. 
aarch64 musllinux -pandas==2.0.3; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' +pandas==2.1.3; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64' pytest==7.4.3 # we run codspeed benchmarks on x86_64 CPython (i.e. native github actions architecture) pytest-codspeed~=2.2.0; implementation_name == "cpython" and platform_machine == 'x86_64' @@ -13,5 +16,5 @@ pytest-pretty==1.2.0 pytest-timeout==2.2.0 pytz==2023.3.post1 # numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux -numpy==1.25.2; python_version >= "3.9" and python_version < "3.12" and implementation_name == "cpython" and platform_machine == 'x86_64' +numpy==1.26.2; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64' exceptiongroup==1.1; python_version < "3.11" From f323e74a50b7951a700bfb526e16d67e75d16072 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 30 Nov 2023 02:45:46 -0700 Subject: [PATCH 137/550] Fix memory leak caused by not visiting the function in a CallValidator during gc (#1100) --- src/validators/call.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/validators/call.rs b/src/validators/call.rs index d43df4984..bf80415af 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -67,6 +67,7 @@ impl BuildValidator for CallValidator { } impl_py_gc_traverse!(CallValidator { + function, arguments_validator, return_validator }); From 7fa450d9666ba37d0c5b48369d844be56598c415 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Sandoval?= Date: Thu, 30 Nov 2023 07:10:39 -0300 Subject: [PATCH 138/550] pass extra argument in arguments validator (#1094) --- src/validators/arguments.rs | 19 +++-- tests/validators/test_arguments.py | 114 +++++++++++++++-------------- 2 files 
changed, 71 insertions(+), 62 deletions(-) diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index aaf5afb6b..a605ff999 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyTuple}; use ahash::AHashSet; use crate::build_tools::py_schema_err; -use crate::build_tools::schema_or_config_same; +use crate::build_tools::{schema_or_config_same, ExtraBehavior}; use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input, ValidationMatch}; use crate::lookup_key::LookupKey; @@ -31,6 +31,7 @@ pub struct ArgumentsValidator { var_args_validator: Option>, var_kwargs_validator: Option>, loc_by_alias: bool, + extra: ExtraBehavior, } impl BuildValidator for ArgumentsValidator { @@ -119,6 +120,7 @@ impl BuildValidator for ArgumentsValidator { None => None, }, loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), + extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?, } .into()) } @@ -307,15 +309,16 @@ impl Validator for ArgumentsValidator { Err(err) => return Err(err), }, None => { - errors.push(ValLineError::new_with_loc( - ErrorTypeDefaults::UnexpectedKeywordArgument, - value, - raw_key.as_loc_item(), - )); + if let ExtraBehavior::Forbid = self.extra { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::UnexpectedKeywordArgument, + value, + raw_key.as_loc_item(), + )); + } } } - } - } + }} } } }}; diff --git a/tests/validators/test_arguments.py b/tests/validators/test_arguments.py index 4cf1d3ad2..45ade32e6 100644 --- a/tests/validators/test_arguments.py +++ b/tests/validators/test_arguments.py @@ -775,57 +775,57 @@ def test_alias_populate_by_name(py_and_json: PyAndJson, input_value, expected): assert v.validate_test(input_value) == expected -def validate(function): - """ - a demo validation decorator to test arguments - """ - parameters = 
signature(function).parameters - - type_hints = get_type_hints(function) - mode_lookup = { - Parameter.POSITIONAL_ONLY: 'positional_only', - Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword', - Parameter.KEYWORD_ONLY: 'keyword_only', - } - - arguments_schema = [] - schema = {'type': 'arguments', 'arguments_schema': arguments_schema} - for i, (name, p) in enumerate(parameters.items()): - if p.annotation is p.empty: - annotation = Any - else: - annotation = type_hints[name] - - assert annotation in (bool, int, float, str, Any), f'schema for {annotation} not implemented' - if annotation in (bool, int, float, str): - arg_schema = {'type': annotation.__name__} - else: - assert annotation is Any - arg_schema = {'type': 'any'} - - if p.kind in mode_lookup: - if p.default is not p.empty: - arg_schema = {'type': 'default', 'schema': arg_schema, 'default': p.default} - s = {'name': name, 'mode': mode_lookup[p.kind], 'schema': arg_schema} - arguments_schema.append(s) - elif p.kind == Parameter.VAR_POSITIONAL: - schema['var_args_schema'] = arg_schema - else: - assert p.kind == Parameter.VAR_KEYWORD, p.kind - schema['var_kwargs_schema'] = arg_schema - - validator = SchemaValidator(schema) - - @wraps(function) - def wrapper(*args, **kwargs): - validated_args, validated_kwargs = validator.validate_python(ArgsKwargs(args, kwargs)) - return function(*validated_args, **validated_kwargs) - - return wrapper +def validate(config=None): + def decorator(function): + parameters = signature(function).parameters + type_hints = get_type_hints(function) + mode_lookup = { + Parameter.POSITIONAL_ONLY: 'positional_only', + Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword', + Parameter.KEYWORD_ONLY: 'keyword_only', + } + + arguments_schema = [] + schema = {'type': 'arguments', 'arguments_schema': arguments_schema} + for i, (name, p) in enumerate(parameters.items()): + if p.annotation is p.empty: + annotation = Any + else: + annotation = type_hints[name] + + assert annotation in 
(bool, int, float, str, Any), f'schema for {annotation} not implemented' + if annotation in (bool, int, float, str): + arg_schema = {'type': annotation.__name__} + else: + assert annotation is Any + arg_schema = {'type': 'any'} + + if p.kind in mode_lookup: + if p.default is not p.empty: + arg_schema = {'type': 'default', 'schema': arg_schema, 'default': p.default} + s = {'name': name, 'mode': mode_lookup[p.kind], 'schema': arg_schema} + arguments_schema.append(s) + elif p.kind == Parameter.VAR_POSITIONAL: + schema['var_args_schema'] = arg_schema + else: + assert p.kind == Parameter.VAR_KEYWORD, p.kind + schema['var_kwargs_schema'] = arg_schema + + validator = SchemaValidator(schema, config=config) + + @wraps(function) + def wrapper(*args, **kwargs): + # Validate arguments using the original schema + validated_args, validated_kwargs = validator.validate_python(ArgsKwargs(args, kwargs)) + return function(*validated_args, **validated_kwargs) + + return wrapper + + return decorator def test_function_any(): - @validate + @validate() def foobar(a, b, c): return a, b, c @@ -842,7 +842,7 @@ def foobar(a, b, c): def test_function_types(): - @validate + @validate() def foobar(a: int, b: int, *, c: int): return a, b, c @@ -894,8 +894,8 @@ def test_function_positional_only(import_execute): # language=Python m = import_execute( """ -def create_function(validate): - @validate +def create_function(validate, config = None): + @validate(config = config) def foobar(a: int, b: int, /, c: int): return a, b, c return foobar @@ -915,6 +915,12 @@ def foobar(a: int, b: int, /, c: int): }, {'type': 'unexpected_keyword_argument', 'loc': ('b',), 'msg': 'Unexpected keyword argument', 'input': 2}, ] + # Allowing extras using the config + foobar = m.create_function(validate, config={'title': 'func', 'extra_fields_behavior': 'allow'}) + assert foobar('1', '2', c=3, d=4) == (1, 2, 3) + # Ignore works similar than allow + foobar = m.create_function(validate, config={'title': 'func', 
'extra_fields_behavior': 'ignore'}) + assert foobar('1', '2', c=3, d=4) == (1, 2, 3) @pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python3.10 or higher') @@ -923,7 +929,7 @@ def test_function_positional_only_default(import_execute): m = import_execute( """ def create_function(validate): - @validate + @validate() def foobar(a: int, b: int = 42, /): return a, b return foobar @@ -940,7 +946,7 @@ def test_function_positional_kwargs(import_execute): m = import_execute( """ def create_function(validate): - @validate + @validate() def foobar(a: int, b: int, /, **kwargs: bool): return a, b, kwargs return foobar @@ -953,7 +959,7 @@ def foobar(a: int, b: int, /, **kwargs: bool): def test_function_args_kwargs(): - @validate + @validate() def foobar(*args, **kwargs): return args, kwargs From a5b390691dff689f9bff590a57420ebed4cd27f3 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 12 Dec 2023 15:55:08 +0300 Subject: [PATCH 139/550] Use input type json when validating a json schema (#1117) Co-authored-by: sydney-runkle --- src/validators/json.rs | 7 +++++-- tests/validators/test_json.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/validators/json.rs b/src/validators/json.rs index 44250ef9d..641892364 100644 --- a/src/validators/json.rs +++ b/src/validators/json.rs @@ -5,7 +5,7 @@ use pyo3::types::PyDict; use jiter::JsonValue; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; -use crate::input::{EitherBytes, Input, ValidationMatch}; +use crate::input::{EitherBytes, Input, InputType, ValidationMatch}; use crate::tools::SchemaDict; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -58,7 +58,10 @@ impl Validator for JsonValidator { match self.validator { Some(ref validator) => { let json_value = JsonValue::parse(json_bytes, true).map_err(|e| map_json_err(input, e, 
json_bytes))?; - validator.validate(py, &json_value, state) + let mut json_state = state.rebind_extra(|e| { + e.input_type = InputType::Json; + }); + validator.validate(py, &json_value, &mut json_state) } None => { let obj = diff --git a/tests/validators/test_json.py b/tests/validators/test_json.py index 228d18e55..9bd553d46 100644 --- a/tests/validators/test_json.py +++ b/tests/validators/test_json.py @@ -1,4 +1,5 @@ import re +from enum import Enum import pytest @@ -152,6 +153,23 @@ def test_dict_key(py_and_json: PyAndJson): ] +def test_enum() -> None: + class MyEnum(Enum): + a = 'a' + b = 'b' + + enum_schema = core_schema.lax_or_strict_schema( + core_schema.no_info_after_validator_function(MyEnum, core_schema.str_schema()), + core_schema.is_instance_schema(MyEnum), + ) + v = core_schema.json_schema(enum_schema) + v = SchemaValidator(v) + assert v.validate_python('"a"') == MyEnum.a + assert v.validate_python('"b"') == MyEnum.b + with pytest.raises(ValidationError): + v.validate_python('"c"') + + def test_any_schema_no_schema(): v = SchemaValidator(core_schema.json_schema()) assert 'validator:None' in plain_repr(v) From dfe5027325d3b2b8f5fdcd1b3a7ee9332deacbdb Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 12 Dec 2023 14:20:39 +0000 Subject: [PATCH 140/550] Implement pickling for ValidationError (#1119) --- src/errors/validation_exception.rs | 14 ++++++++++++++ tests/test_errors.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 91ef60a0d..e77b21974 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -341,6 +341,20 @@ impl ValidationError { fn __str__(&self, py: Python) -> String { self.__repr__(py) } + + fn __reduce__(slf: &PyCell) -> PyResult<(&PyAny, PyObject)> { + let py = slf.py(); + let callable = slf.getattr("from_exception_data")?; + let borrow = slf.try_borrow()?; + let args = ( + borrow.title.as_ref(py), + 
borrow.errors(py, include_url_env(py), true, true)?, + borrow.input_type.into_py(py), + borrow.hide_input, + ) + .into_py(slf.py()); + Ok((callable, args)) + } } // TODO: is_utf8_char_boundary, floor_char_boundary and ceil_char_boundary diff --git a/tests/test_errors.py b/tests/test_errors.py index 05815aec5..88dcace8f 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,4 +1,5 @@ import enum +import pickle import re import sys from decimal import Decimal @@ -1074,3 +1075,17 @@ def test_hide_input_in_json() -> None: for error in exc_info.value.errors(include_input=False): assert 'input' not in error + + +@pytest.mark.skipif( + sys.version_info < (3, 9) and sys.implementation.name == 'pypy', + reason='PyPy before 3.9 cannot pickle this correctly', +) +def test_validation_error_pickle() -> None: + s = SchemaValidator({'type': 'int'}) + with pytest.raises(ValidationError) as exc_info: + s.validate_python('definitely not an int') + + original = exc_info.value + roundtripped = pickle.loads(pickle.dumps(original)) + assert original.errors() == roundtripped.errors() From cfd1efd8ffa1b686b8d25f4ee858edf69ce99cf2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:30:35 +0000 Subject: [PATCH 141/550] Bump uraimo/run-on-arch-action from 2.5.1 to 2.6.0 (#1107) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bb0918f8a..f8d0fa83c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -632,7 +632,7 @@ jobs: name: pypi_files_pgo path: dist - - uses: uraimo/run-on-arch-action@v2.5.1 + - uses: uraimo/run-on-arch-action@v2.6.0 name: install & test with: arch: ${{ matrix.target }} From c7b6bb2f221033f55bf48c8a3859aa773d26c0a3 Mon Sep 17 00:00:00 2001 From: 
"dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:30:49 +0000 Subject: [PATCH 142/550] Bump CodSpeedHQ/action from 1 to 2 (#1108) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/codspeed.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index b01d4ff7e..697815a16 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -72,6 +72,6 @@ jobs: RUSTFLAGS: "-Cprofile-use=${{ github.workspace }}/merged.profdata" - name: Run CodSpeed benchmarks - uses: CodSpeedHQ/action@v1 + uses: CodSpeedHQ/action@v2 with: run: pytest tests/benchmarks/ --codspeed From af36330f983093de096e1d3e9f8ca7236ad91201 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:31:03 +0000 Subject: [PATCH 143/550] Bump mymindstorm/setup-emsdk from 12 to 13 (#1106) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f8d0fa83c..b1f6c0236 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -307,7 +307,7 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 - - uses: mymindstorm/setup-emsdk@v12 + - uses: mymindstorm/setup-emsdk@v13 with: # NOTE!: as per https://github.com/pydantic/pydantic-core/pull/149 this version needs to match the version # in node_modules/pyodide/repodata.json, to get the version, run: From 167a9f1e80b4d681a9bd1d7390503aa742f93bb7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:31:43 +0000 Subject: [PATCH 144/550] Bump griffe from 0.36.9 to 0.38.0 (#1112) 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 38eacfdbc..d2cc86cd9 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ -griffe==0.36.9 +griffe==0.38.0 pyright==1.1.334 ruff==0.1.5 mypy==1.6.1 From c1b4febcfec3561e0aa464f673f9734c58c44dda Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:32:19 +0000 Subject: [PATCH 145/550] Bump ruff from 0.1.5 to 0.1.6 (#1109) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index d2cc86cd9..6bcb6fd93 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ griffe==0.38.0 pyright==1.1.334 -ruff==0.1.5 +ruff==0.1.6 mypy==1.6.1 From 77068b7c5212e4ed2c20798cc269c5c7a6d6782a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:32:30 +0000 Subject: [PATCH 146/550] Bump uuid from 1.5.0 to 1.6.1 (#1105) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2ba82639c..349019a9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -629,9 +629,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.5.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" [[package]] name = "version_check" diff --git a/Cargo.toml b/Cargo.toml index 20f4b79b8..50fd5b262 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ idna = "0.4.0" base64 = "0.21.5" num-bigint = "0.4.4" python3-dll-a = "0.2.7" -uuid = "1.5.0" +uuid = "1.6.1" jiter = {version = "0.0.5", features = ["python"]} [lib] From 8da73b1aa8a9d0fede10bcb23fe352c37ad199f5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:32:57 +0000 Subject: [PATCH 147/550] Bump url from 2.4.1 to 2.5.0 (#1102) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 26 ++++++++++++++++++-------- Cargo.toml | 2 +- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 349019a9e..4a497db95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,9 +74,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] @@ -114,6 +114,16 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "2.0.0" @@ -306,9 +316,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "proc-macro2" @@ -326,7 +336,7 @@ dependencies = [ "ahash", "base64", "enum_dispatch", - "idna", + "idna 0.4.0", "jiter", "num-bigint", "pyo3", @@ -618,12 +628,12 @@ checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" [[package]] name = "url" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", - "idna", + "idna 0.5.0", "percent-encoding", ] diff --git a/Cargo.toml b/Cargo.toml index 50fd5b262..11b95455f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ serde = { version = "1.0.190", features = ["derive"] } speedate = "0.13.0" smallvec = "1.11.1" ahash = "0.8.6" -url = "2.4.1" +url = "2.5.0" # idna is already required by url, added here to be explicit idna = "0.4.0" base64 = "0.21.5" From 1e5a4a22e7d2d8d573a3c507a61ffc2e74eccc29 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:33:24 +0000 Subject: [PATCH 148/550] Bump smallvec from 1.11.1 to 1.11.2 (#1103) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4a497db95..658dc1a3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,9 +525,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "speedate" diff --git a/Cargo.toml b/Cargo.toml index 11b95455f..9868be24b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ serde_json = {version = "1.0.108", features = ["arbitrary_precision", "preserve_ enum_dispatch = "0.3.8" serde = { version = "1.0.190", features = ["derive"] } speedate = "0.13.0" -smallvec = "1.11.1" +smallvec = "1.11.2" ahash = "0.8.6" url = "2.5.0" # idna is already required by url, added here to be explicit From a5e0e0947277e158e368f5a23333f95bc54231f7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:40:03 +0000 Subject: [PATCH 149/550] Bump serde from 1.0.190 to 1.0.193 (#1104) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 658dc1a3a..7a9d1d17e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -493,18 +493,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.190" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d3c334ca1ee894a2c6f6ad698fe8c435b76d504b13d436f0685d648d6d96f7" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.190" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67c5609f394e5c2bd7fc51efda478004ea80ef42fee983d5c67a65e34f32c0e3" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 
9868be24b..7b61e32bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" serde_json = {version = "1.0.108", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" -serde = { version = "1.0.190", features = ["derive"] } +serde = { version = "1.0.193", features = ["derive"] } speedate = "0.13.0" smallvec = "1.11.2" ahash = "0.8.6" From 360bd88d74cbd25cc74f5baecdaf241c9ba8510c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:43:04 +0000 Subject: [PATCH 150/550] Bump ruff from 0.1.6 to 0.1.7 (#1121) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 6bcb6fd93..348697f36 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ griffe==0.38.0 pyright==1.1.334 -ruff==0.1.6 +ruff==0.1.7 mypy==1.6.1 From bec63dbae945c6d6492810414ba97b44fdbd2000 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:49:37 +0000 Subject: [PATCH 151/550] Bump pyright from 1.1.334 to 1.1.339 (#1120) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 348697f36..c97e47a1b 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ griffe==0.38.0 -pyright==1.1.334 +pyright==1.1.339 ruff==0.1.7 mypy==1.6.1 From 10ad10f9f1c171e4d8bb5926c6bb2956a1f1e551 Mon Sep 17 00:00:00 2001 From: Sydney Runkle 
<54324534+sydney-runkle@users.noreply.github.com> Date: Tue, 19 Dec 2023 09:52:35 -0600 Subject: [PATCH 152/550] Support `yyyy-MM-DD` string for datetimes (#1124) --- python/pydantic_core/core_schema.py | 1 + src/errors/types.rs | 5 +++ src/validators/datetime.rs | 57 +++++++++++++++++++++++++++-- tests/test_errors.py | 1 + tests/test_hypothesis.py | 10 ++--- tests/validators/test_datetime.py | 15 ++++++-- 6 files changed, 77 insertions(+), 12 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index daed22d48..ee305284a 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3787,6 +3787,7 @@ def definition_reference_schema( 'datetime_type', 'datetime_parsing', 'datetime_object_invalid', + 'datetime_from_date_parsing', 'datetime_past', 'datetime_future', 'timezone_naive', diff --git a/src/errors/types.rs b/src/errors/types.rs index 5c3fc1a7c..cfa96221e 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -338,6 +338,9 @@ error_types! { DatetimeObjectInvalid { error: {ctx_type: String, ctx_fn: field_from_context}, }, + DatetimeFromDateParsing { + error: {ctx_type: Cow<'static, str>, ctx_fn: cow_field_from_context}, + }, DatetimePast {}, DatetimeFuture {}, // --------------------- @@ -529,6 +532,7 @@ impl ErrorType { Self::DatetimeType {..} => "Input should be a valid datetime", Self::DatetimeParsing {..} => "Input should be a valid datetime, {error}", Self::DatetimeObjectInvalid {..} => "Invalid datetime object, got {error}", + Self::DatetimeFromDateParsing {..} => "Input should be a valid datetime or date, {error}", Self::DatetimePast {..} => "Input should be in the past", Self::DatetimeFuture {..} => "Input should be in the future", Self::TimezoneNaive {..} => "Input should not have timezone info", @@ -684,6 +688,7 @@ impl ErrorType { Self::DateFromDatetimeParsing { error, .. } => render!(tmpl, error), Self::TimeParsing { error, .. 
} => render!(tmpl, error), Self::DatetimeParsing { error, .. } => render!(tmpl, error), + Self::DatetimeFromDateParsing { error, .. } => render!(tmpl, error), Self::DatetimeObjectInvalid { error, .. } => render!(tmpl, error), Self::TimezoneOffset { tz_expected, tz_actual, .. diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 156fad699..8779ea76c 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -2,7 +2,7 @@ use pyo3::intern; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{PyDateTime, PyDict, PyString}; -use speedate::DateTime; +use speedate::{DateTime, Time}; use std::cmp::Ordering; use strum::EnumMessage; @@ -13,6 +13,7 @@ use crate::input::{EitherDateTime, Input}; use crate::tools::SchemaDict; +use super::Exactness; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug, Clone)] @@ -65,9 +66,15 @@ impl Validator for DateTimeValidator { state: &mut ValidationState, ) -> ValResult { let strict = state.strict_or(self.strict); - let datetime = input - .validate_datetime(strict, self.microseconds_precision)? - .unpack(state); + let datetime = match input.validate_datetime(strict, self.microseconds_precision) { + Ok(val_match) => val_match.unpack(state), + // if the error was a parsing error, in lax mode we allow dates and add the time 00:00:00 + Err(line_errors @ ValError::LineErrors(..)) if !strict => { + state.floor_exactness(Exactness::Lax); + datetime_from_date(input)?.ok_or(line_errors)? + } + Err(otherwise) => return Err(otherwise), + }; if let Some(constraints) = &self.constraints { // if we get an error from as_speedate, it's probably because the input datetime was invalid // specifically had an invalid tzinfo, hence here we return a validation error @@ -132,6 +139,48 @@ impl Validator for DateTimeValidator { } } +/// In lax mode, if the input is not a datetime, we try parsing the input as a date and add the "00:00:00" time. 
+/// +/// Ok(None) means that this is not relevant to datetimes (the input was not a date nor a string) +fn datetime_from_date<'data>(input: &'data impl Input<'data>) -> Result>, ValError> { + let either_date = match input.validate_date(false) { + Ok(val_match) => val_match.into_inner(), + // if the error was a parsing error, update the error type from DateParsing to DatetimeFromDateParsing + Err(ValError::LineErrors(mut line_errors)) => { + if line_errors.iter_mut().fold(false, |has_parsing_error, line_error| { + if let ErrorType::DateParsing { error, .. } = &mut line_error.error_type { + line_error.error_type = ErrorType::DatetimeFromDateParsing { + error: std::mem::take(error), + context: None, + }; + true + } else { + has_parsing_error + } + }) { + return Err(ValError::LineErrors(line_errors)); + } + return Ok(None); + } + // for any other error, don't return it + Err(_) => return Ok(None), + }; + + let zero_time = Time { + hour: 0, + minute: 0, + second: 0, + microsecond: 0, + tz_offset: Some(0), + }; + + let datetime = DateTime { + date: either_date.as_raw()?, + time: zero_time, + }; + Ok(Some(EitherDateTime::Raw(datetime))) +} + #[derive(Debug, Clone)] struct DateTimeConstraints { le: Option, diff --git a/tests/test_errors.py b/tests/test_errors.py index 88dcace8f..fd3f34d8f 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -332,6 +332,7 @@ def f(input_value, info): ('time_parsing', 'Input should be in a valid time format, foobar', {'error': 'foobar'}), ('datetime_type', 'Input should be a valid datetime', None), ('datetime_parsing', 'Input should be a valid datetime, foobar', {'error': 'foobar'}), + ('datetime_from_date_parsing', 'Input should be a valid datetime or date, foobar', {'error': 'foobar'}), ('datetime_object_invalid', 'Invalid datetime object, got foobar', {'error': 'foobar'}), ('datetime_past', 'Input should be in the past', None), ('datetime_future', 'Input should be in the future', None), diff --git a/tests/test_hypothesis.py 
b/tests/test_hypothesis.py index ea0d67f0d..f02d1b3a9 100644 --- a/tests/test_hypothesis.py +++ b/tests/test_hypothesis.py @@ -47,12 +47,12 @@ def test_datetime_binary(datetime_schema, data): except ValidationError as exc: assert exc.errors(include_url=False) == [ { - 'type': 'datetime_parsing', - 'loc': (), - 'msg': IsStr(regex='Input should be a valid datetime, .+'), - 'input': IsBytes(), 'ctx': {'error': IsStr()}, - } + 'input': IsBytes(), + 'loc': (), + 'msg': IsStr(regex='Input should be a valid datetime or date, .+'), + 'type': 'datetime_from_date_parsing', + }, ] diff --git a/tests/validators/test_datetime.py b/tests/validators/test_datetime.py index 89e9c1c53..1d4a216f9 100644 --- a/tests/validators/test_datetime.py +++ b/tests/validators/test_datetime.py @@ -19,6 +19,7 @@ [ (datetime(2022, 6, 8, 12, 13, 14), datetime(2022, 6, 8, 12, 13, 14)), (date(2022, 6, 8), datetime(2022, 6, 8)), + ('2022-01-01', datetime(2022, 1, 1, 0, 0, 0, tzinfo=timezone.utc)), ('2022-06-08T12:13:14', datetime(2022, 6, 8, 12, 13, 14)), ('1000000000000', datetime(2001, 9, 9, 1, 46, 40, tzinfo=timezone.utc)), (b'2022-06-08T12:13:14', datetime(2022, 6, 8, 12, 13, 14)), @@ -36,8 +37,14 @@ (float('nan'), Err('Input should be a valid datetime, NaN values not permitted [type=datetime_parsing,')), (float('inf'), Err('Input should be a valid datetime, dates after 9999')), (float('-inf'), Err('Input should be a valid datetime, dates before 1600')), - ('-', Err('Input should be a valid datetime, input is too short [type=datetime_parsing,')), - ('+', Err('Input should be a valid datetime, input is too short [type=datetime_parsing,')), + ('-', Err('Input should be a valid datetime or date, input is too short [type=datetime_from_date_parsing,')), + ('+', Err('Input should be a valid datetime or date, input is too short [type=datetime_from_date_parsing,')), + ( + '2022-02-30', + Err( + 'Input should be a valid datetime or date, day value is outside expected range 
[type=datetime_from_date_parsing,' + ), + ), ], ) def test_datetime(input_value, expected): @@ -119,7 +126,9 @@ def test_keep_tz_bound(): (1655205632.331557, datetime(2022, 6, 14, 11, 20, 32, microsecond=331557, tzinfo=timezone.utc)), ( '2022-06-08T12:13:14+24:00', - Err('Input should be a valid datetime, timezone offset must be less than 24 hours [type=datetime_parsing,'), + Err( + 'Input should be a valid datetime or date, unexpected extra characters at the end of the input [type=datetime_from_date_parsing,' + ), ), (True, Err('Input should be a valid datetime [type=datetime_type')), (None, Err('Input should be a valid datetime [type=datetime_type')), From ea3ec7ed558d63c59ed572677bb0edcf2fff7a53 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 19 Dec 2023 17:49:40 +0000 Subject: [PATCH 153/550] Sync 2.14 branch into main (#1127) Co-authored-by: sydney-runkle Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Co-authored-by: Samuel Colvin --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7a9d1d17e..d65bd9896 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,7 +331,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.14.4" +version = "2.14.5" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 7b61e32bd..422059b2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.14.4" +version = "2.14.5" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 2cb87a62eb1e5f8c47b1df22ebf28da82aa13836 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 20 Dec 2023 17:07:17 +0000 Subject: [PATCH 154/550] fix memory leak with recursive definitions creating reference cycles (#1125) --- src/definitions.rs | 94 +++++++++---------- .../type_serializers/definitions.rs | 34 ++++--- src/validators/definitions.rs | 64 +++++++------ 3 files 
changed, 100 insertions(+), 92 deletions(-) diff --git a/src/definitions.rs b/src/definitions.rs index 4627fd2d1..46a77196d 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -8,7 +8,7 @@ use std::{ fmt::Debug, sync::{ atomic::{AtomicBool, Ordering}, - Arc, OnceLock, + Arc, OnceLock, Weak, }, }; @@ -28,47 +28,50 @@ use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse}; /// They get indexed by a ReferenceId, which are integer identifiers /// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer} /// gets build. -#[derive(Clone)] pub struct Definitions(AHashMap, Definition>); -/// Internal type which contains a definition to be filled -pub struct Definition(Arc>); - -struct DefinitionInner { - value: OnceLock, - name: LazyName, +struct Definition { + value: Arc>, + name: Arc, } /// Reference to a definition. pub struct DefinitionRef { - name: Arc, - value: Definition, + reference: Arc, + // We use a weak reference to the definition to avoid a reference cycle + // when recursive definitions are used. 
+ value: Weak>, + name: Arc, } // DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone) impl Clone for DefinitionRef { fn clone(&self) -> Self { Self { - name: self.name.clone(), + reference: self.reference.clone(), value: self.value.clone(), + name: self.name.clone(), } } } impl DefinitionRef { pub fn id(&self) -> usize { - Arc::as_ptr(&self.value.0) as usize + Weak::as_ptr(&self.value) as usize } pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str { - match self.value.0.value.get() { - Some(value) => self.value.0.name.get_or_init(|| init(value)), + let Some(definition) = self.value.upgrade() else { + return "..."; + }; + match definition.get() { + Some(value) => self.name.get_or_init(|| init(value)), None => "...", } } - pub fn get(&self) -> Option<&T> { - self.value.0.value.get() + pub fn read(&self, f: impl FnOnce(Option<&T>) -> R) -> R { + f(self.value.upgrade().as_ref().and_then(|value| value.get())) } } @@ -96,15 +99,9 @@ impl Debug for Definitions { } } -impl Clone for Definition { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - impl Debug for Definition { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.0.value.get() { + match self.value.get() { Some(value) => value.fmt(f), None => "...".fmt(f), } @@ -113,7 +110,7 @@ impl Debug for Definition { impl PyGcTraverse for DefinitionRef { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { - if let Some(value) = self.value.0.value.get() { + if let Some(value) = self.value.upgrade().as_ref().and_then(|v| v.get()) { value.py_gc_traverse(visit)?; } Ok(()) @@ -123,7 +120,7 @@ impl PyGcTraverse for DefinitionRef { impl PyGcTraverse for Definitions { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { for value in self.0.values() { - if let Some(value) = value.0.value.get() { + if let Some(value) = value.value.get() { value.py_gc_traverse(visit)?; } } @@ -131,7 +128,7 @@ impl 
PyGcTraverse for Definitions { } } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct DefinitionsBuilder { definitions: Definitions, } @@ -148,45 +145,48 @@ impl DefinitionsBuilder { // We either need a String copy or two hashmap lookups // Neither is better than the other // We opted for the easier outward facing API - let name = Arc::new(reference.to_string()); - let value = match self.definitions.0.entry(name.clone()) { + let reference = Arc::new(reference.to_string()); + let value = match self.definitions.0.entry(reference.clone()) { Entry::Occupied(entry) => entry.into_mut(), - Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner { - value: OnceLock::new(), - name: LazyName::new(), - }))), + Entry::Vacant(entry) => entry.insert(Definition { + value: Arc::new(OnceLock::new()), + name: Arc::new(LazyName::new()), + }), }; DefinitionRef { - name, - value: value.clone(), + reference, + value: Arc::downgrade(&value.value), + name: value.name.clone(), } } /// Add a definition, returning the ReferenceId that maps to it pub fn add_definition(&mut self, reference: String, value: T) -> PyResult> { - let name = Arc::new(reference); - let value = match self.definitions.0.entry(name.clone()) { + let reference = Arc::new(reference); + let value = match self.definitions.0.entry(reference.clone()) { Entry::Occupied(entry) => { let definition = entry.into_mut(); - match definition.0.value.set(value) { - Ok(()) => definition.clone(), - Err(_) => return py_schema_err!("Duplicate ref: `{}`", name), + match definition.value.set(value) { + Ok(()) => definition, + Err(_) => return py_schema_err!("Duplicate ref: `{}`", reference), } } - Entry::Vacant(entry) => entry - .insert(Definition(Arc::new(DefinitionInner { - value: OnceLock::from(value), - name: LazyName::new(), - }))) - .clone(), + Entry::Vacant(entry) => entry.insert(Definition { + value: Arc::new(OnceLock::from(value)), + name: Arc::new(LazyName::new()), + }), }; - Ok(DefinitionRef { name, value }) + 
Ok(DefinitionRef { + reference, + value: Arc::downgrade(&value.value), + name: value.name.clone(), + }) } /// Consume this Definitions into a vector of items, indexed by each items ReferenceId pub fn finish(self) -> PyResult> { for (reference, def) in &self.definitions.0 { - if def.0.value.get().is_none() { + if def.value.get().is_none() { return py_schema_err!("Definitions error: definition `{}` was never filled", reference); } } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index b7bf63365..99dae5bcd 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -68,15 +68,17 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - let comb_serializer = self.definition.get().unwrap(); - let value_id = extra.rec_guard.add(value, self.definition.id())?; - let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); - r + self.definition.read(|comb_serializer| { + let comb_serializer = comb_serializer.unwrap(); + let value_id = extra.rec_guard.add(value, self.definition.id())?; + let r = comb_serializer.to_python(value, include, exclude, extra); + extra.rec_guard.pop(value_id, self.definition.id()); + r + }) } fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { - self.definition.get().unwrap().json_key(key, extra) + self.definition.read(|s| s.unwrap().json_key(key, extra)) } fn serde_serialize( @@ -87,14 +89,16 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let comb_serializer = self.definition.get().unwrap(); - let value_id = extra - .rec_guard - .add(value, self.definition.id()) - .map_err(py_err_se_err)?; - let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); - r + 
self.definition.read(|comb_serializer| { + let comb_serializer = comb_serializer.unwrap(); + let value_id = extra + .rec_guard + .add(value, self.definition.id()) + .map_err(py_err_se_err)?; + let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); + extra.rec_guard.pop(value_id, self.definition.id()); + r + }) } fn get_name(&self) -> &str { @@ -102,6 +106,6 @@ impl TypeSerializer for DefinitionRefSerializer { } fn retry_with_lax_check(&self) -> bool { - self.definition.get().unwrap().retry_with_lax_check() + self.definition.read(|s| s.unwrap().retry_with_lax_check()) } } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 7297bd27a..0b5f78c10 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -73,23 +73,25 @@ impl Validator for DefinitionRefValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult { - let validator = self.definition.get().unwrap(); - if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) - } else { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); + self.definition.read(|validator| { + let validator = validator.unwrap(); + if let Some(id) = input.identity() { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) + } else { + if state.recursion_guard.incr_depth() { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); + } + let output = validator.validate(py, input, state); + state.recursion_guard.remove(id, self.definition.id()); + 
state.recursion_guard.decr_depth(); + output } - let output = validator.validate(py, input, state); - state.recursion_guard.remove(id, self.definition.id()); - state.recursion_guard.decr_depth(); - output + } else { + validator.validate(py, input, state) } - } else { - validator.validate(py, input, state) - } + }) } fn validate_assignment<'data>( @@ -100,23 +102,25 @@ impl Validator for DefinitionRefValidator { field_value: &'data PyAny, state: &mut ValidationState, ) -> ValResult { - let validator = self.definition.get().unwrap(); - if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) - } else { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); + self.definition.read(|validator| { + let validator = validator.unwrap(); + if let Some(id) = obj.identity() { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) + } else { + if state.recursion_guard.incr_depth() { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); + } + let output = validator.validate_assignment(py, obj, field_name, field_value, state); + state.recursion_guard.remove(id, self.definition.id()); + state.recursion_guard.decr_depth(); + output } - let output = validator.validate_assignment(py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.definition.id()); - state.recursion_guard.decr_depth(); - output + } else { + validator.validate_assignment(py, obj, field_name, field_value, state) } - } else { - validator.validate_assignment(py, obj, field_name, field_value, state) - } + }) } fn 
get_name(&self) -> &str { From 5c896fe51be476db62b84d60fc9088f7602344f1 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 21 Dec 2023 02:07:16 -0700 Subject: [PATCH 155/550] Support indirect definition references (#1130) --- python/pydantic_core/core_schema.py | 7 +++++-- tests/validators/test_definitions.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index ee305284a..5c54d5cc8 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3576,12 +3576,13 @@ def definitions_schema(schema: CoreSchema, definitions: list[CoreSchema]) -> Def class DefinitionReferenceSchema(TypedDict, total=False): type: Required[Literal['definition-ref']] schema_ref: Required[str] + ref: str metadata: Any serialization: SerSchema def definition_reference_schema( - schema_ref: str, metadata: Any = None, serialization: SerSchema | None = None + schema_ref: str, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None ) -> DefinitionReferenceSchema: """ Returns a schema that points to a schema stored in "definitions", this is useful for nested recursive @@ -3606,7 +3607,9 @@ def definition_reference_schema( metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ - return _dict_not_none(type='definition-ref', schema_ref=schema_ref, metadata=metadata, serialization=serialization) + return _dict_not_none( + type='definition-ref', schema_ref=schema_ref, ref=ref, metadata=metadata, serialization=serialization + ) MYPY = False diff --git a/tests/validators/test_definitions.py b/tests/validators/test_definitions.py index d742da5dd..606c8a6d9 100644 --- a/tests/validators/test_definitions.py +++ b/tests/validators/test_definitions.py @@ -140,3 +140,13 @@ def test_use_after(): ) ) assert 
v.validate_python(['1', '2']) == (1, 2) + + +def test_definition_chain(): + v = SchemaValidator( + core_schema.definitions_schema( + core_schema.definition_reference_schema('foo'), + [core_schema.definition_reference_schema(ref='foo', schema_ref='bar'), core_schema.int_schema(ref='bar')], + ), + ) + assert v.validate_python('1') == 1 From d706aa45a67dd519a363d5c31c9b39c898b68f24 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 21 Dec 2023 18:42:01 +0000 Subject: [PATCH 156/550] fix "run-on-arch" tests (#1131) --- .github/workflows/ci.yml | 4 ++-- tests/requirements.txt | 4 +++- tests/test_docstrings.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b1f6c0236..6075bb9ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -643,11 +643,11 @@ jobs: if command -v apt-get &> /dev/null; then echo "installing python & pip with apt-get..." apt-get update - apt-get install -y --no-install-recommends python3 python3-pip python3-venv + apt-get install -y --no-install-recommends python3 python3-pip python3-venv git else echo "installing python & pip with apk..." apk update - apk add python3 py3-pip + apk add python3 py3-pip git fi run: | set -x diff --git a/tests/requirements.txt b/tests/requirements.txt index 67e8e002b..5e6efed59 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -9,7 +9,9 @@ pandas==2.1.3; python_version >= "3.9" and python_version < "3.13" and implement pytest==7.4.3 # we run codspeed benchmarks on x86_64 CPython (i.e. 
native github actions architecture) pytest-codspeed~=2.2.0; implementation_name == "cpython" and platform_machine == 'x86_64' -pytest-examples==0.0.10 +# pytest-examples currently depends on aiohttp via black; we don't want to build +# it on platforms like aarch64 musllinux in CI +pytest-examples==0.0.10; implementation_name == "cpython" and platform_machine == 'x86_64' pytest-speed==0.3.5 pytest-mock==3.11.1 pytest-pretty==1.2.0 diff --git a/tests/test_docstrings.py b/tests/test_docstrings.py index fb729723f..25f5bf1e4 100644 --- a/tests/test_docstrings.py +++ b/tests/test_docstrings.py @@ -12,7 +12,7 @@ def find_examples(*_directories): return [] -@pytest.mark.skipif(sys.platform not in {'linux', 'darwin'}, reason='Only on linux and macos') +@pytest.mark.skipif(CodeExample is None or sys.platform not in {'linux', 'darwin'}, reason='Only on linux and macos') @pytest.mark.parametrize('example', find_examples('python/pydantic_core/core_schema.py'), ids=str) def test_docstrings(example: CodeExample, eval_example: EvalExample): eval_example.set_config(quotes='single') @@ -25,7 +25,7 @@ def test_docstrings(example: CodeExample, eval_example: EvalExample): eval_example.run_print_check(example) -@pytest.mark.skipif(sys.platform not in {'linux', 'darwin'}, reason='Only on linux and macos') +@pytest.mark.skipif(CodeExample is None or sys.platform not in {'linux', 'darwin'}, reason='Only on linux and macos') @pytest.mark.parametrize('example', find_examples('README.md'), ids=str) def test_readme(example: CodeExample, eval_example: EvalExample): eval_example.set_config(line_length=100, quotes='single') From 341b1bd8551cd474dc72dd10d53f59ca6373a341 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 8 Jan 2024 14:25:20 +0000 Subject: [PATCH 157/550] drop Python 3.7, and PyPy 3.7 and 3.8 (#1129) --- .github/workflows/ci.yml | 23 +++++++++-------------- Cargo.lock | 2 +- Cargo.toml | 2 +- README.md | 2 +- generate_self_schema.py | 2 +- pyproject.toml | 3 +-- 6 files changed, 
14 insertions(+), 20 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6075bb9ba..484713730 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -65,14 +65,11 @@ jobs: fail-fast: false matrix: python-version: - - '3.7' - '3.8' - '3.9' - '3.10' - '3.11' - '3.12' - - 'pypy3.7' - - 'pypy3.8' - 'pypy3.9' - 'pypy3.10' @@ -385,19 +382,19 @@ jobs: - os: linux manylinux: auto target: armv7 - interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 + interpreter: 3.8 3.9 3.10 3.11 3.12 - os: linux manylinux: auto target: ppc64le - interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 + interpreter: 3.8 3.9 3.10 3.11 3.12 - os: linux manylinux: auto target: s390x - interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 + interpreter: 3.8 3.9 3.10 3.11 3.12 - os: linux manylinux: auto target: x86_64 - interpreter: pypy3.7 pypy3.8 pypy3.9 pypy3.10 + interpreter: pypy3.9 pypy3.10 # musllinux - os: linux @@ -414,7 +411,7 @@ jobs: target: x86_64 - os: macos target: aarch64 - interpreter: 3.7 3.8 3.9 pypy3.8 pypy3.9 pypy3.10 + interpreter: 3.8 3.9 pypy3.9 pypy3.10 # windows; # x86_64 pypy builds are not PGO optimized @@ -422,11 +419,11 @@ jobs: # aarch64 only 3.11 and up, also not PGO optimized - os: windows target: x86_64 - interpreter: pypy3.8 pypy3.9 pypy3.10 + interpreter: pypy3.9 pypy3.10 - os: windows target: i686 python-architecture: x86 - interpreter: 3.7 3.8 3.9 3.10 3.11 3.12 + interpreter: 3.8 3.9 3.10 3.11 3.12 - os: windows target: aarch64 interpreter: 3.11 3.12 @@ -451,7 +448,7 @@ jobs: with: target: ${{ matrix.target }} manylinux: ${{ matrix.manylinux == 'manylinux' && 'auto' || matrix.manylinux }} - args: --release --out dist --interpreter ${{ matrix.interpreter || '3.7 3.8 3.9 3.10 3.11 3.12 pypy3.7 pypy3.8 pypy3.9 pypy3.10' }} + args: --release --out dist --interpreter ${{ matrix.interpreter || '3.8 3.9 3.10 3.11 3.12 pypy3.9 pypy3.10' }} rust-toolchain: stable docker-options: -e CI @@ -472,7 +469,7 @@ jobs: fail-fast: false matrix: os: [linux, windows, 
macos] - interpreter: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] + interpreter: ['3.8', '3.9', '3.10', '3.11', '3.12'] include: # standard runners with override for macos arm - os: linux @@ -484,8 +481,6 @@ jobs: runs-on: macos-latest-xlarge exclude: # macos arm only supported from 3.10 and up - - os: macos - interpreter: '3.7' - os: macos interpreter: '3.8' - os: macos diff --git a/Cargo.lock b/Cargo.lock index d65bd9896..f85c39729 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,7 +331,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.14.5" +version = "2.15.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 422059b2b..2c798fc52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.14.5" +version = "2.15.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" diff --git a/README.md b/README.md index 9b2d9df8f..ec625ad26 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ except ValidationError as e: You'll need rust stable [installed](https://rustup.rs/), or rust nightly if you want to generate accurate coverage. 
-With rust and python 3.7+ installed, compiling pydantic-core should be possible with roughly the following: +With rust and python 3.8+ installed, compiling pydantic-core should be possible with roughly the following: ```bash # clone this repo or your fork diff --git a/generate_self_schema.py b/generate_self_schema.py index 2c190bbad..4e1235b1a 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -191,7 +191,7 @@ def eval_forward_ref(type_: Any) -> Any: try: return type_._evaluate(core_schema.__dict__, None, set()) except TypeError: - # for older python (3.7 at least) + # for Python 3.8 return type_._evaluate(core_schema.__dict__, None) diff --git a/pyproject.toml b/pyproject.toml index a40f160f0..5cdbca806 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = 'maturin' [project] name = 'pydantic_core' -requires-python = '>=3.7' +requires-python = '>=3.8' authors = [ {name = 'Samuel Colvin', email = 's@muelcolvin.com'} ] @@ -16,7 +16,6 @@ classifiers = [ 'Programming Language :: Python', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', From 1c7554faeefd9db668e7a7535bcac010ae1cf080 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 08:45:49 -0600 Subject: [PATCH 158/550] Bump ahash from 0.8.6 to 0.8.7 (#1135) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 12 ++++++------ Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f85c39729..0618e11cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "ahash" -version = "0.8.6" +version = "0.8.7" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a" +checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" dependencies = [ "cfg-if", "getrandom", @@ -714,18 +714,18 @@ checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" [[package]] name = "zerocopy" -version = "0.7.20" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd66a62464e3ffd4e37bd09950c2b9dd6c4f8767380fabba0d523f9a775bc85a" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.20" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "255c4596d41e6916ced49cfafea18727b24d67878fa180ddfd69b9df34fd1726" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 2c798fc52..95bd03ff2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ enum_dispatch = "0.3.8" serde = { version = "1.0.193", features = ["derive"] } speedate = "0.13.0" smallvec = "1.11.2" -ahash = "0.8.6" +ahash = "0.8.7" url = "2.5.0" # idna is already required by url, added here to be explicit idna = "0.4.0" From bb9c84156cb808f259d4a3bd68075afd7eaebc9c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 08:46:19 -0600 Subject: [PATCH 159/550] Bump pytest from 7.4.3 to 7.4.4 (#1143) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 5e6efed59..16942c3e5 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -6,7 
+6,7 @@ hypothesis==6.79.4 git+https://github.com/dateutil/dateutil.git@f2293200747fb03d56c6c5997bfebeabe703576f # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux pandas==2.1.3; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64' -pytest==7.4.3 +pytest==7.4.4 # we run codspeed benchmarks on x86_64 CPython (i.e. native github actions architecture) pytest-codspeed~=2.2.0; implementation_name == "cpython" and platform_machine == 'x86_64' # pytest-examples currently depends on aiohttp via black; we don't want to build From be727a9524d91eee444223088942832a0d915bc6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 08:46:36 -0600 Subject: [PATCH 160/550] Bump actions/setup-python from 4 to 5 (#1138) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 24 ++++++++++++------------ .github/workflows/codspeed.yml | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 484713730..278099172 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: - run: rustup component add llvm-tools-preview - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' @@ -87,7 +87,7 @@ jobs: key: test-v3 - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -124,7 +124,7 @@ jobs: key: ${{ matrix.os }}-v1 - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' @@ -155,7 +155,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 
with: python-version: ${{ matrix.python-version }} @@ -188,7 +188,7 @@ jobs: path: pydantic-core - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' @@ -230,7 +230,7 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: '3.11' @@ -276,7 +276,7 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: '3.10' @@ -291,7 +291,7 @@ jobs: - name: set up python id: setup-python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' @@ -433,7 +433,7 @@ jobs: - uses: actions/checkout@v4 - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' architecture: ${{ matrix.python-architecture || 'x64' }} @@ -491,7 +491,7 @@ jobs: - uses: actions/checkout@v4 - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.interpreter }} @@ -669,7 +669,7 @@ jobs: - uses: actions/checkout@v4 - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.11' @@ -699,7 +699,7 @@ jobs: - uses: actions/checkout@v4 - name: set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index 697815a16..5d69d4d5f 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -16,7 +16,7 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: '3.12' From 5f0ad2c47f635ec275bc920e5608541117cf10cc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 08:56:00 -0600 Subject: 
[PATCH 161/550] Bump actions/upload-artifact from 3 to 4 (#1139) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 278099172..245dc5dcf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -329,7 +329,7 @@ jobs: ls -lh dist/ ls -l dist/ - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: wasm_wheels path: dist @@ -356,7 +356,7 @@ jobs: command: sdist args: --out dist rust-toolchain: stable - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: pypi_files path: dist @@ -456,7 +456,7 @@ jobs: - run: twine check --strict dist/* - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: pypi_files path: dist @@ -551,7 +551,7 @@ jobs: - run: ${{ matrix.ls || 'ls -lh' }} dist/ - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: pypi_files_pgo path: dist From 3757e21806fab2fd92200a5051602a5b497bdc24 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 08:56:21 -0600 Subject: [PATCH 162/550] Bump actions/download-artifact from 3 to 4 (#1137) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 245dc5dcf..1bf0698a1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -564,7 +564,7 @@ jobs: - uses: actions/checkout@v4 - name: get dist artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: pypi_files path: dist @@ -576,7 +576,7 @@ jobs: echo "`ls dist | 
wc -l` files" - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: pypi_files_pgo path: dist @@ -616,13 +616,13 @@ jobs: - uses: actions/checkout@v4 - name: get dist artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: pypi_files path: dist - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: pypi_files_pgo path: dist @@ -674,13 +674,13 @@ jobs: python-version: '3.11' - name: get dist artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: pypi_files path: dist - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: pypi_files_pgo path: dist @@ -709,13 +709,13 @@ jobs: run: python .github/check_version.py - name: get dist artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: pypi_files path: dist - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: pypi_files_pgo path: dist @@ -729,7 +729,7 @@ jobs: TWINE_PASSWORD: ${{ secrets.pypi_token }} - name: get wasm dist artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: wasm_wheels path: wasm From 22e5fba2f2f10efc999e6c9c28c3b2a713cf68b8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 08:56:53 -0600 Subject: [PATCH 163/550] Bump mypy from 1.6.1 to 1.8.0 (#1141) Signed-off-by: dependabot[bot] Co-authored-by: 
dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index c97e47a1b..e1e5e4268 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ griffe==0.38.0 pyright==1.1.339 ruff==0.1.7 -mypy==1.6.1 +mypy==1.8.0 From ab7dae5593cc7c1a7cd841063b5a94e42c032ed0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 08:57:13 -0600 Subject: [PATCH 164/550] Bump hypothesis from 6.79.4 to 6.92.5 (#1145) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 16942c3e5..01040a24e 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,6 +1,6 @@ coverage==7.2.7 dirty-equals==0.6.0 -hypothesis==6.79.4 +hypothesis==6.92.5 # TODO: remove manual override for dateutil once a version newer than 2.8.2 is # released which removes use of deprecated utcfromtimestamp git+https://github.com/dateutil/dateutil.git@f2293200747fb03d56c6c5997bfebeabe703576f From ccf13e26cb49305e20efbb75334237606a825749 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 09:00:01 -0600 Subject: [PATCH 165/550] Bump pyo3 from 0.20.0 to 0.20.1 (#1136) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> --- Cargo.lock | 20 ++++++++++---------- Cargo.toml | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0618e11cd..a88bed6bb 100644 --- a/Cargo.lock +++ 
b/Cargo.lock @@ -356,9 +356,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +checksum = "e82ad98ce1991c9c70c3464ba4187337b9c45fcbbb060d46dca15f0c075e14e2" dependencies = [ "cfg-if", "indoc", @@ -374,9 +374,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" +checksum = "5503d0b3aee2c7a8dbb389cd87cd9649f675d4c7f60ca33699a3e3859d81a891" dependencies = [ "once_cell", "python3-dll-a", @@ -385,9 +385,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +checksum = "18a79e8d80486a00d11c0dcb27cd2aa17c022cc95c677b461f01797226ba8f41" dependencies = [ "libc", "pyo3-build-config", @@ -395,9 +395,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +checksum = "1f4b0dc7eaa578604fab11c8c7ff8934c71249c61d4def8e272c76ed879f03d4" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -407,9 +407,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +checksum = "816a4f709e29ddab2e3cdfe94600d554c5556cad0ddfeea95c47b580c3247fa4" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 95bd03ff2..55c3dd129 100644 --- a/Cargo.toml +++ 
b/Cargo.toml @@ -26,7 +26,7 @@ include = [ ] [dependencies] -pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] } +pyo3 = { version = "0.20.1", features = ["generate-import-lib", "num-bigint"] } regex = "1.10.2" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" @@ -68,12 +68,12 @@ debug = true strip = false [dev-dependencies] -pyo3 = { version = "0.20.0", features = ["auto-initialize"] } +pyo3 = { version = "0.20.1", features = ["auto-initialize"] } [build-dependencies] version_check = "0.9.4" # used where logic has to be version/distribution specific, e.g. pypy -pyo3-build-config = { version = "0.20.0" } +pyo3-build-config = { version = "0.20.1" } [lints.clippy] dbg_macro = "warn" From d1e18e4ac19d7f3ccbd0bb013d319ea42ffe015e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 09:03:01 -0600 Subject: [PATCH 166/550] Bump serde_json from 1.0.108 to 1.0.109 (#1134) @davidhewitt mentioned that we can ignore these performance regressions for now, as they're caused by "just volatility due to PyO3's internal datastructure" that will soon be reworked. 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a88bed6bb..b37062696 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -513,9 +513,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.108" +version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +checksum = "cb0652c533506ad7a2e353cce269330d6afd8bdfb6d75e0ace5b35aacbd7b9e9" dependencies = [ "indexmap", "itoa", diff --git a/Cargo.toml b/Cargo.toml index 55c3dd129..f332f1327 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ pyo3 = { version = "0.20.1", features = ["generate-import-lib", "num-bigint"] } regex = "1.10.2" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" -serde_json = {version = "1.0.108", features = ["arbitrary_precision", "preserve_order"]} +serde_json = {version = "1.0.109", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.193", features = ["derive"] } speedate = "0.13.0" From 9a2f6204e0f8680421185bf3081ba09b135d9dbf Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 10:08:55 -0600 Subject: [PATCH 167/550] Bump coverage from 7.2.7 to 7.4.0 (#1142) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 01040a24e..dc211e9c7 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,4 +1,4 @@ -coverage==7.2.7 +coverage==7.4.0 dirty-equals==0.6.0 hypothesis==6.92.5 # TODO: remove manual override for dateutil once a 
version newer than 2.8.2 is From f3d0cc5240101043f755bc65d2b4dbf2b679de6f Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 8 Jan 2024 18:36:40 +0200 Subject: [PATCH 168/550] Rework `PYDANTIC_ERRORS_OMIT_URL` to `PYDANTIC_ERRORS_INCLUDE_URL` (#1123) --- python/pydantic_core/_pydantic_core.pyi | 12 ++++++ src/errors/validation_exception.rs | 32 ++++++++++++---- tests/test_errors.py | 50 +++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 8 deletions(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 382a6c804..b8c1f4e94 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -787,6 +787,18 @@ class ValidationError(ValueError): a JSON string. """ + def __repr__(self) -> str: + """ + A string representation of the validation error. + + Whether or not documentation URLs are included in the repr is controlled by the + environment variable `PYDANTIC_ERRORS_INCLUDE_URL` being set to `1` or + `true`; by default, URLs are shown. + + Due to implementation details, this environment variable can only be set once, + before the first validation error is created. + """ + @final class PydanticCustomError(ValueError): def __new__( diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index e77b21974..95090c1ac 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -193,15 +193,31 @@ impl ValidationError { static URL_ENV_VAR: GILOnceCell = GILOnceCell::new(); -fn _get_include_url_env() -> bool { - match std::env::var("PYDANTIC_ERRORS_OMIT_URL") { - Ok(val) => val.is_empty(), - Err(_) => true, - } -} - fn include_url_env(py: Python) -> bool { - *URL_ENV_VAR.get_or_init(py, _get_include_url_env) + *URL_ENV_VAR.get_or_init(py, || { + // Check the legacy env var first. 
+ // Using `var_os` here instead of `var` because we don't care about + // the value (or whether we're able to decode it as UTF-8), just + // whether it exists (and if it does, whether it's non-empty). + match std::env::var_os("PYDANTIC_ERRORS_OMIT_URL") { + Some(val) => { + // We don't care whether warning succeeded or not, hence the assignment + let _ = PyErr::warn( + py, + py.get_type::(), + "PYDANTIC_ERRORS_OMIT_URL is deprecated, use PYDANTIC_ERRORS_INCLUDE_URL instead", + 1, + ); + // If OMIT_URL exists but is empty, we include the URL: + val.is_empty() + } + // If the legacy env var doesn't exist, check the documented one: + None => match std::env::var("PYDANTIC_ERRORS_INCLUDE_URL") { + Ok(val) => val == "1" || val.to_lowercase() == "true", + Err(_) => true, + }, + } + }) } static URL_PREFIX: GILOnceCell = GILOnceCell::new(); diff --git a/tests/test_errors.py b/tests/test_errors.py index fd3f34d8f..18821b98c 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -1,6 +1,8 @@ import enum +import os import pickle import re +import subprocess import sys from decimal import Decimal from typing import Any, Optional @@ -1090,3 +1092,51 @@ def test_validation_error_pickle() -> None: original = exc_info.value roundtripped = pickle.loads(pickle.dumps(original)) assert original.errors() == roundtripped.errors() + + +@pytest.mark.skipif('PYDANTIC_ERRORS_INCLUDE_URL' in os.environ, reason="can't test when envvar is set") +def test_errors_include_url() -> None: + s = SchemaValidator({'type': 'int'}) + with pytest.raises(ValidationError) as exc_info: + s.validate_python('definitely not an int') + assert 'https://errors.pydantic.dev' in repr(exc_info.value) + + +@pytest.mark.skipif(sys.platform == 'emscripten', reason='no subprocesses on emscripten') +@pytest.mark.parametrize( + ('env_var', 'env_var_value', 'expected_to_have_url'), + [ + ('PYDANTIC_ERRORS_INCLUDE_URL', None, True), + ('PYDANTIC_ERRORS_INCLUDE_URL', '1', True), + ('PYDANTIC_ERRORS_INCLUDE_URL', 
'True', True), + ('PYDANTIC_ERRORS_INCLUDE_URL', 'no', False), + ('PYDANTIC_ERRORS_INCLUDE_URL', '0', False), + # Legacy environment variable, will raise a deprecation warning: + ('PYDANTIC_ERRORS_OMIT_URL', '1', False), + ('PYDANTIC_ERRORS_OMIT_URL', None, True), + ], +) +def test_errors_include_url_envvar(env_var, env_var_value, expected_to_have_url) -> None: + """ + Test the `PYDANTIC_ERRORS_INCLUDE_URL` environment variable. + + Since it can only be set before `ValidationError.__repr__()` is first called, + we need to spawn a subprocess to test it. + """ + code = "import pydantic_core; pydantic_core.SchemaValidator({'type': 'int'}).validate_python('ooo')" + env = os.environ.copy() + env.pop('PYDANTIC_ERRORS_OMIT_URL', None) # in case the ambient environment has it set + if env_var_value is not None: + env[env_var] = env_var_value + env['PYTHONDEVMODE'] = '1' # required to surface the deprecation warning + result = subprocess.run( + [sys.executable, '-c', code], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + encoding='utf-8', + env=env, + ) + assert result.returncode == 1 + if 'PYDANTIC_ERRORS_OMIT_URL' in env: + assert 'PYDANTIC_ERRORS_OMIT_URL is deprecated' in result.stdout + assert ('https://errors.pydantic.dev' in result.stdout) == expected_to_have_url From e3ae7f643a7f8656eafdbb8854aad39ede756309 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Mon, 8 Jan 2024 21:03:27 -0600 Subject: [PATCH 169/550] Support serialization mode specification from model config and `SerializationConfig` (#1122) Co-authored-by: David Hewitt --- python/pydantic_core/_pydantic_core.pyi | 4 + src/errors/validation_exception.rs | 2 +- src/serializers/config.rs | 144 ++++++++---------- src/serializers/extra.rs | 4 +- src/serializers/infer.rs | 11 +- src/serializers/mod.rs | 10 +- src/serializers/type_serializers/bytes.rs | 17 ++- src/serializers/type_serializers/timedelta.rs | 26 ++-- tests/serializers/test_bytes.py | 36 
++++- 9 files changed, 143 insertions(+), 111 deletions(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index b8c1f4e94..a7b727f86 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -355,6 +355,7 @@ def to_json( round_trip: bool = False, timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', bytes_mode: Literal['utf8', 'base64'] = 'utf8', + inf_nan_mode: Literal['null', 'constants'] = 'constants', serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, ) -> bytes: @@ -373,6 +374,7 @@ def to_json( round_trip: Whether to enable serialization and validation round-trip support. timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`. bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`. + inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`. serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails `""` will be used. fallback: A function to call when an unknown value is encountered, @@ -414,6 +416,7 @@ def to_jsonable_python( round_trip: bool = False, timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', bytes_mode: Literal['utf8', 'base64'] = 'utf8', + inf_nan_mode: Literal['null', 'constants'] = 'constants', serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, ) -> Any: @@ -432,6 +435,7 @@ def to_jsonable_python( round_trip: Whether to enable serialization and validation round-trip support. timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`. bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`. + inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`. serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails `""` will be used. 
fallback: A function to call when an unknown value is encountered, diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 95090c1ac..24626df29 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -319,7 +319,7 @@ impl ValidationError { include_context: bool, include_input: bool, ) -> PyResult<&'py PyString> { - let state = SerializationState::new("iso8601", "utf8")?; + let state = SerializationState::new("iso8601", "utf8", "constants")?; let extra = state.extra(py, &SerMode::Json, true, false, false, true, None); let serializer = ValidationErrorSerializer { py, diff --git a/src/serializers/config.rs b/src/serializers/config.rs index e83497f64..422ee4162 100644 --- a/src/serializers/config.rs +++ b/src/serializers/config.rs @@ -15,60 +15,98 @@ use crate::tools::SchemaDict; use super::errors::py_err_se_err; #[derive(Debug, Clone)] +#[allow(clippy::struct_field_names)] pub(crate) struct SerializationConfig { pub timedelta_mode: TimedeltaMode, pub bytes_mode: BytesMode, + pub inf_nan_mode: InfNanMode, } impl SerializationConfig { pub fn from_config(config: Option<&PyDict>) -> PyResult { let timedelta_mode = TimedeltaMode::from_config(config)?; let bytes_mode = BytesMode::from_config(config)?; + let inf_nan_mode = InfNanMode::from_config(config)?; Ok(Self { timedelta_mode, bytes_mode, + inf_nan_mode, }) } - pub fn from_args(timedelta_mode: &str, bytes_mode: &str) -> PyResult { + pub fn from_args(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult { Ok(Self { timedelta_mode: TimedeltaMode::from_str(timedelta_mode)?, bytes_mode: BytesMode::from_str(bytes_mode)?, + inf_nan_mode: InfNanMode::from_str(inf_nan_mode)?, }) } } -#[derive(Default, Debug, Clone)] -pub(crate) enum TimedeltaMode { - #[default] - Iso8601, - Float, +pub trait FromConfig { + fn from_config(config: Option<&PyDict>) -> PyResult + where + Self: Sized; } -impl FromStr for TimedeltaMode { - type Err = PyErr; - - 
fn from_str(s: &str) -> Result { - match s { - "iso8601" => Ok(Self::Iso8601), - "float" => Ok(Self::Float), - s => py_schema_err!( - "Invalid timedelta serialization mode: `{}`, expected `iso8601` or `float`", - s - ), +macro_rules! serialization_mode { + ($name:ident, $config_key:expr, $($variant:ident => $value:expr),* $(,)?) => { + #[derive(Default, Debug, Clone, PartialEq, Eq)] + pub(crate) enum $name { + #[default] + $($variant,)* } - } + + impl FromStr for $name { + type Err = PyErr; + + fn from_str(s: &str) -> Result { + match s { + $($value => Ok(Self::$variant),)* + s => py_schema_err!( + concat!("Invalid ", stringify!($name), " serialization mode: `{}`, expected ", $($value, " or "),*), + s + ), + } + } + } + + impl FromConfig for $name { + fn from_config(config: Option<&PyDict>) -> PyResult { + let Some(config_dict) = config else { + return Ok(Self::default()); + }; + let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), $config_key))?; + raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str) + } + } + + }; } -impl TimedeltaMode { - pub fn from_config(config: Option<&PyDict>) -> PyResult { - let Some(config_dict) = config else { - return Ok(Self::default()); - }; - let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_timedelta"))?; - raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str) - } +serialization_mode! { + TimedeltaMode, + "ser_json_timedelta", + Iso8601 => "iso8601", + Float => "float", +} + +serialization_mode! { + BytesMode, + "ser_json_bytes", + Utf8 => "utf8", + Base64 => "base64", + Hex => "hex", +} + +serialization_mode! 
{ + InfNanMode, + "ser_json_inf_nan", + Null => "null", + Constants => "constants", +} +impl TimedeltaMode { fn total_seconds(py_timedelta: &PyDelta) -> PyResult<&PyAny> { py_timedelta.call_method0(intern!(py_timedelta.py(), "total_seconds")) } @@ -124,39 +162,7 @@ impl TimedeltaMode { } } -#[derive(Default, Debug, Clone)] -pub(crate) enum BytesMode { - #[default] - Utf8, - Base64, - Hex, -} - -impl FromStr for BytesMode { - type Err = PyErr; - - fn from_str(s: &str) -> Result { - match s { - "utf8" => Ok(Self::Utf8), - "base64" => Ok(Self::Base64), - "hex" => Ok(Self::Hex), - s => py_schema_err!( - "Invalid bytes serialization mode: `{}`, expected `utf8`, `base64` or `hex`", - s - ), - } - } -} - impl BytesMode { - pub fn from_config(config: Option<&PyDict>) -> PyResult { - let Some(config_dict) = config else { - return Ok(Self::default()); - }; - let raw_mode = config_dict.get_as::<&str>(intern!(config_dict.py(), "ser_json_bytes"))?; - raw_mode.map_or_else(|| Ok(Self::default()), Self::from_str) - } - pub fn bytes_to_string<'py>(&self, py: Python, bytes: &'py [u8]) -> PyResult> { match self { Self::Utf8 => from_utf8(bytes) @@ -190,28 +196,6 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr { } } -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub(crate) enum InfNanMode { - #[default] - Null, - Constants, -} - -impl FromStr for InfNanMode { - type Err = PyErr; - - fn from_str(s: &str) -> Result { - match s { - "null" => Ok(Self::Null), - "constants" => Ok(Self::Constants), - s => py_schema_err!( - "Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`", - s - ), - } - } -} - impl FromPyObject<'_> for InfNanMode { fn extract(ob: &'_ PyAny) -> PyResult { let s = ob.extract::<&str>()?; diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 7a9b84704..37307055e 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -21,10 +21,10 @@ pub(crate) struct SerializationState { } impl 
SerializationState { - pub fn new(timedelta_mode: &str, bytes_mode: &str) -> PyResult { + pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult { let warnings = CollectWarnings::new(false); let rec_guard = SerRecursionGuard::default(); - let config = SerializationConfig::from_args(timedelta_mode, bytes_mode)?; + let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?; Ok(Self { warnings, rec_guard, diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 265967c57..13c20062b 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -10,6 +10,7 @@ use pyo3::types::{ use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; use crate::input::{EitherTimedelta, Int}; +use crate::serializers::config::InfNanMode; use crate::serializers::errors::SERIALIZATION_ERR_MARKER; use crate::serializers::filter::SchemaFilter; use crate::serializers::shared::{PydanticSerializer, TypeSerializer}; @@ -120,10 +121,16 @@ pub(crate) fn infer_to_python_known( let value = match extra.mode { SerMode::Json => match ob_type { // `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types - ObType::None | ObType::Bool | ObType::Int | ObType::Float | ObType::Str => value.into_py(py), + ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py), // have to do this to make sure subclasses of for example str are upcast to `str` ObType::IntSubclass => extract_i64(value)?.into_py(py), - ObType::FloatSubclass => value.extract::()?.into_py(py), + ObType::Float | ObType::FloatSubclass => { + let v = value.extract::()?; + if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null { + return Ok(py.None()); + } + v.into_py(py) + } ObType::Decimal => value.to_string().into_py(py), ObType::StrSubclass => value.extract::<&str>()?.into_py(py), ObType::Bytes => extra diff --git a/src/serializers/mod.rs 
b/src/serializers/mod.rs index e9208a510..8159691cb 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -213,7 +213,7 @@ impl SchemaSerializer { #[pyfunction] #[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false, timedelta_mode = "iso8601", bytes_mode = "utf8", - serialize_unknown = false, fallback = None))] + inf_nan_mode = "constants", serialize_unknown = false, fallback = None))] pub fn to_json( py: Python, value: &PyAny, @@ -225,10 +225,11 @@ pub fn to_json( round_trip: bool, timedelta_mode: &str, bytes_mode: &str, + inf_nan_mode: &str, serialize_unknown: bool, fallback: Option<&PyAny>, ) -> PyResult { - let state = SerializationState::new(timedelta_mode, bytes_mode)?; + let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?; let extra = state.extra( py, &SerMode::Json, @@ -248,7 +249,7 @@ pub fn to_json( #[allow(clippy::too_many_arguments)] #[pyfunction] #[pyo3(signature = (value, *, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false, - timedelta_mode = "iso8601", bytes_mode = "utf8", serialize_unknown = false, fallback = None))] + timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None))] pub fn to_jsonable_python( py: Python, value: &PyAny, @@ -259,10 +260,11 @@ pub fn to_jsonable_python( round_trip: bool, timedelta_mode: &str, bytes_mode: &str, + inf_nan_mode: &str, serialize_unknown: bool, fallback: Option<&PyAny>, ) -> PyResult { - let state = SerializationState::new(timedelta_mode, bytes_mode)?; + let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?; let extra = state.extra( py, &SerMode::Json, diff --git a/src/serializers/type_serializers/bytes.rs b/src/serializers/type_serializers/bytes.rs index 67dbe794b..bd354aa75 100644 --- a/src/serializers/type_serializers/bytes.rs +++ 
b/src/serializers/type_serializers/bytes.rs @@ -4,6 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict}; use crate::definitions::DefinitionsBuilder; +use crate::serializers::config::{BytesMode, FromConfig}; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode, @@ -11,17 +12,20 @@ use super::{ }; #[derive(Debug, Clone)] -pub struct BytesSerializer; +pub struct BytesSerializer { + bytes_mode: BytesMode, +} impl BuildSerializer for BytesSerializer { const EXPECTED_TYPE: &'static str = "bytes"; fn build( _schema: &PyDict, - _config: Option<&PyDict>, + config: Option<&PyDict>, _definitions: &mut DefinitionsBuilder, ) -> PyResult { - Ok(Self {}.into()) + let bytes_mode = BytesMode::from_config(config)?; + Ok(Self { bytes_mode }.into()) } } @@ -38,8 +42,7 @@ impl TypeSerializer for BytesSerializer { let py = value.py(); match value.downcast::() { Ok(py_bytes) => match extra.mode { - SerMode::Json => extra - .config + SerMode::Json => self .bytes_mode .bytes_to_string(py, py_bytes.as_bytes()) .map(|s| s.into_py(py)), @@ -54,7 +57,7 @@ impl TypeSerializer for BytesSerializer { fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { match key.downcast::() { - Ok(py_bytes) => extra.config.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()), + Ok(py_bytes) => self.bytes_mode.bytes_to_string(key.py(), py_bytes.as_bytes()), Err(_) => { extra.warnings.on_fallback_py(self.get_name(), key, extra)?; infer_json_key(key, extra) @@ -71,7 +74,7 @@ impl TypeSerializer for BytesSerializer { extra: &Extra, ) -> Result { match value.downcast::() { - Ok(py_bytes) => extra.config.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer), + Ok(py_bytes) => self.bytes_mode.serialize_bytes(py_bytes.as_bytes(), serializer), Err(_) => { extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; infer_serialize(value, serializer, include, exclude, extra) diff --git 
a/src/serializers/type_serializers/timedelta.rs b/src/serializers/type_serializers/timedelta.rs index baadbe0aa..042a8ffea 100644 --- a/src/serializers/type_serializers/timedelta.rs +++ b/src/serializers/type_serializers/timedelta.rs @@ -5,6 +5,7 @@ use pyo3::types::PyDict; use crate::definitions::DefinitionsBuilder; use crate::input::EitherTimedelta; +use crate::serializers::config::{FromConfig, TimedeltaMode}; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, SerMode, @@ -12,17 +13,20 @@ use super::{ }; #[derive(Debug, Clone)] -pub struct TimeDeltaSerializer; +pub struct TimeDeltaSerializer { + timedelta_mode: TimedeltaMode, +} impl BuildSerializer for TimeDeltaSerializer { const EXPECTED_TYPE: &'static str = "timedelta"; fn build( _schema: &PyDict, - _config: Option<&PyDict>, + config: Option<&PyDict>, _definitions: &mut DefinitionsBuilder, ) -> PyResult { - Ok(Self {}.into()) + let timedelta_mode = TimedeltaMode::from_config(config)?; + Ok(Self { timedelta_mode }.into()) } } @@ -38,10 +42,7 @@ impl TypeSerializer for TimeDeltaSerializer { ) -> PyResult { match extra.mode { SerMode::Json => match EitherTimedelta::try_from(value) { - Ok(either_timedelta) => extra - .config - .timedelta_mode - .either_delta_to_json(value.py(), &either_timedelta), + Ok(either_timedelta) => self.timedelta_mode.either_delta_to_json(value.py(), &either_timedelta), Err(_) => { extra.warnings.on_fallback_py(self.get_name(), value, extra)?; infer_to_python(value, include, exclude, extra) @@ -53,7 +54,7 @@ impl TypeSerializer for TimeDeltaSerializer { fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { match EitherTimedelta::try_from(key) { - Ok(either_timedelta) => extra.config.timedelta_mode.json_key(key.py(), &either_timedelta), + Ok(either_timedelta) => self.timedelta_mode.json_key(key.py(), &either_timedelta), Err(_) => { extra.warnings.on_fallback_py(self.get_name(), key, extra)?; infer_json_key(key, extra) 
@@ -70,12 +71,9 @@ impl TypeSerializer for TimeDeltaSerializer { extra: &Extra, ) -> Result { match EitherTimedelta::try_from(value) { - Ok(either_timedelta) => { - extra - .config - .timedelta_mode - .timedelta_serialize(value.py(), &either_timedelta, serializer) - } + Ok(either_timedelta) => self + .timedelta_mode + .timedelta_serialize(value.py(), &either_timedelta, serializer), Err(_) => { extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; infer_serialize(value, serializer, include, exclude, extra) diff --git a/tests/serializers/test_bytes.py b/tests/serializers/test_bytes.py index 13849bed0..cc2d44785 100644 --- a/tests/serializers/test_bytes.py +++ b/tests/serializers/test_bytes.py @@ -4,7 +4,7 @@ import pytest -from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema +from pydantic_core import PydanticSerializationError, SchemaSerializer, core_schema, to_json def test_bytes(): @@ -126,3 +126,37 @@ def test_any_bytes_base64(): assert s.to_json(b'foobar') == b'"Zm9vYmFy"' assert s.to_json({b'foobar': 123}) == b'{"Zm9vYmFy":123}' assert s.to_python({b'foobar': 123}, mode='json') == {'Zm9vYmFy': 123} + + +class BasicModel: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + +def test_bytes_mode_set_via_model_config_not_serializer_config(): + s = SchemaSerializer( + core_schema.model_schema( + BasicModel, + core_schema.model_fields_schema( + { + 'foo': core_schema.model_field(core_schema.bytes_schema()), + } + ), + config=core_schema.CoreConfig(ser_json_bytes='base64'), + ) + ) + + bm = BasicModel(foo=b'foobar') + assert s.to_python(bm) == {'foo': b'foobar'} + assert s.to_json(bm) == b'{"foo":"Zm9vYmFy"}' + assert s.to_python(bm, mode='json') == {'foo': 'Zm9vYmFy'} + + # assert doesn't override serializer config + # in V3, we can change the serialization settings provided to to_json to override model config settings, + # but that'd be a breaking change + 
BasicModel.__pydantic_serializer__ = s + assert to_json(bm, bytes_mode='utf8') == b'{"foo":"Zm9vYmFy"}' + + assert to_json({'foo': b'some bytes'}, bytes_mode='base64') == b'{"foo":"c29tZSBieXRlcw=="}' + assert to_json({'bar': bm}, bytes_mode='base64') == b'{"bar":{"foo":"Zm9vYmFy"}}' From 740e996e7c593eb5fdc60cf71b8643f1d269b17f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Jan 2024 16:22:54 +0000 Subject: [PATCH 170/550] Bump idna from 0.4.0 to 0.5.0 (#1101) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 14 ++------------ Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b37062696..2f17cec97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,16 +104,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -[[package]] -name = "idna" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - [[package]] name = "idna" version = "0.5.0" @@ -336,7 +326,7 @@ dependencies = [ "ahash", "base64", "enum_dispatch", - "idna 0.4.0", + "idna", "jiter", "num-bigint", "pyo3", @@ -633,7 +623,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", - "idna 0.5.0", + "idna", "percent-encoding", ] diff --git a/Cargo.toml b/Cargo.toml index f332f1327..66a8667e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ smallvec = "1.11.2" ahash = "0.8.7" url = "2.5.0" # idna is already required by url, added here to be explicit -idna = "0.4.0" +idna = "0.5.0" base64 = 
"0.21.5" num-bigint = "0.4.4" python3-dll-a = "0.2.7" From a1b934f20753b78ce74364945542a3db8e71969a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Jan 2024 16:26:45 +0000 Subject: [PATCH 171/550] Bump dirty-equals from 0.6.0 to 0.7.1.post0 (#1111) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index dc211e9c7..4297177d8 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,5 +1,5 @@ coverage==7.4.0 -dirty-equals==0.6.0 +dirty-equals==0.7.1.post0 hypothesis==6.92.5 # TODO: remove manual override for dateutil once a version newer than 2.8.2 is # released which removes use of deprecated utcfromtimestamp From 4df7624c12b57dc47fe1de72acb2795c1631dbb9 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 9 Jan 2024 10:41:55 -0700 Subject: [PATCH 172/550] Add unified tuple validator that can handle "variadic" tuples via PEP-646 (#865) --- generate_self_schema.py | 2 +- python/pydantic_core/core_schema.py | 105 +++-- src/serializers/shared.rs | 6 +- src/serializers/type_serializers/tuple.rs | 349 ++++++--------- src/validators/mod.rs | 6 +- src/validators/tuple.rs | 409 ++++++++++++------ src/validators/union.rs | 22 +- tests/benchmarks/complete_schema.py | 14 +- tests/benchmarks/test_micro_benchmarks.py | 8 +- tests/serializers/test_list_tuple.py | 22 +- tests/serializers/test_none.py | 2 +- tests/test_schema_functions.py | 12 +- tests/test_typing.py | 4 +- .../validators/test_definitions_recursive.py | 6 +- tests/validators/test_frozenset.py | 4 +- tests/validators/test_list.py | 4 +- tests/validators/test_set.py | 4 +- tests/validators/test_tuple.py | 124 +++--- tests/validators/test_with_default.py | 18 +- 19 files changed, 603 
insertions(+), 518 deletions(-) diff --git a/generate_self_schema.py b/generate_self_schema.py index 4e1235b1a..8d27247d6 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -142,7 +142,7 @@ def type_dict_schema( # noqa: C901 'type': 'union', 'choices': [ schema_ref_validator, - {'type': 'tuple-positional', 'items_schema': [schema_ref_validator, {'type': 'str'}]}, + {'type': 'tuple', 'items_schema': [schema_ref_validator, {'type': 'str'}]}, ], }, } diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 5c54d5cc8..44f58c48a 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -1384,16 +1384,7 @@ def list_schema( ) -class TuplePositionalSchema(TypedDict, total=False): - type: Required[Literal['tuple-positional']] - items_schema: Required[List[CoreSchema]] - extras_schema: CoreSchema - strict: bool - ref: str - metadata: Any - serialization: IncExSeqOrElseSerSchema - - +# @deprecated('tuple_positional_schema is deprecated. 
Use pydantic_core.core_schema.tuple_schema instead.') def tuple_positional_schema( items_schema: list[CoreSchema], *, @@ -1402,7 +1393,7 @@ def tuple_positional_schema( ref: str | None = None, metadata: Any = None, serialization: IncExSeqOrElseSerSchema | None = None, -) -> TuplePositionalSchema: +) -> TupleSchema: """ Returns a schema that matches a tuple of schemas, e.g.: @@ -1427,10 +1418,14 @@ def tuple_positional_schema( metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ - return _dict_not_none( - type='tuple-positional', + if extras_schema is not None: + variadic_item_index = len(items_schema) + items_schema = items_schema + [extras_schema] + else: + variadic_item_index = None + return tuple_schema( items_schema=items_schema, - extras_schema=extras_schema, + variadic_item_index=variadic_item_index, strict=strict, ref=ref, metadata=metadata, @@ -1438,9 +1433,55 @@ def tuple_positional_schema( ) -class TupleVariableSchema(TypedDict, total=False): - type: Required[Literal['tuple-variable']] - items_schema: CoreSchema +# @deprecated('tuple_variable_schema is deprecated. 
Use pydantic_core.core_schema.tuple_schema instead.') +def tuple_variable_schema( + items_schema: CoreSchema | None = None, + *, + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: IncExSeqOrElseSerSchema | None = None, +) -> TupleSchema: + """ + Returns a schema that matches a tuple of a given schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.tuple_variable_schema( + items_schema=core_schema.int_schema(), min_length=0, max_length=10 + ) + v = SchemaValidator(schema) + assert v.validate_python(('1', 2, 3)) == (1, 2, 3) + ``` + + Args: + items_schema: The value must be a tuple with items that match this schema + min_length: The value must be a tuple with at least this many items + max_length: The value must be a tuple with at most this many items + strict: The value must be a tuple with exactly this many items + ref: Optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return tuple_schema( + items_schema=[items_schema or any_schema()], + variadic_item_index=0, + min_length=min_length, + max_length=max_length, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class TupleSchema(TypedDict, total=False): + type: Required[Literal['tuple']] + items_schema: Required[List[CoreSchema]] + variadic_item_index: int min_length: int max_length: int strict: bool @@ -1449,41 +1490,45 @@ class TupleVariableSchema(TypedDict, total=False): serialization: IncExSeqOrElseSerSchema -def tuple_variable_schema( - items_schema: CoreSchema | None = None, +def tuple_schema( + items_schema: list[CoreSchema], *, + variadic_item_index: int | None = None, min_length: int | None = None, max_length: int | None = 
None, strict: bool | None = None, ref: str | None = None, metadata: Any = None, serialization: IncExSeqOrElseSerSchema | None = None, -) -> TupleVariableSchema: +) -> TupleSchema: """ - Returns a schema that matches a tuple of a given schema, e.g.: + Returns a schema that matches a tuple of schemas, with an optional variadic item, e.g.: ```py from pydantic_core import SchemaValidator, core_schema - schema = core_schema.tuple_variable_schema( - items_schema=core_schema.int_schema(), min_length=0, max_length=10 + schema = core_schema.tuple_schema( + [core_schema.int_schema(), core_schema.str_schema(), core_schema.float_schema()], + variadic_item_index=1, ) v = SchemaValidator(schema) - assert v.validate_python(('1', 2, 3)) == (1, 2, 3) + assert v.validate_python((1, 'hello', 'world', 1.5)) == (1, 'hello', 'world', 1.5) ``` Args: - items_schema: The value must be a tuple with items that match this schema + items_schema: The value must be a tuple with items that match these schemas + variadic_item_index: The index of the schema in `items_schema` to be treated as variadic (following PEP 646) min_length: The value must be a tuple with at least this many items max_length: The value must be a tuple with at most this many items strict: The value must be a tuple with exactly this many items - ref: optional unique identifier of the schema, used to reference the schema in other places + ref: Optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ return _dict_not_none( - type='tuple-variable', + type='tuple', items_schema=items_schema, + variadic_item_index=variadic_item_index, min_length=min_length, max_length=max_length, strict=strict, @@ -3634,8 +3679,7 @@ def definition_reference_schema( IsSubclassSchema, CallableSchema, ListSchema, - TuplePositionalSchema, - TupleVariableSchema, + TupleSchema, 
SetSchema, FrozenSetSchema, GeneratorSchema, @@ -3689,8 +3733,7 @@ def definition_reference_schema( 'is-subclass', 'callable', 'list', - 'tuple-positional', - 'tuple-variable', + 'tuple', 'set', 'frozenset', 'generator', diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index cfccc748a..11aac037d 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -140,8 +140,7 @@ combined_serializer! { Union: super::type_serializers::union::UnionSerializer; Literal: super::type_serializers::literal::LiteralSerializer; Recursive: super::type_serializers::definitions::DefinitionRefSerializer; - TuplePositional: super::type_serializers::tuple::TuplePositionalSerializer; - TupleVariable: super::type_serializers::tuple::TupleVariableSerializer; + Tuple: super::type_serializers::tuple::TupleSerializer; } } @@ -248,8 +247,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Union(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Literal(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit), - CombinedSerializer::TuplePositional(inner) => inner.py_gc_traverse(visit), - CombinedSerializer::TupleVariable(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), } } diff --git a/src/serializers/type_serializers/tuple.rs b/src/serializers/type_serializers/tuple.rs index 00c61e250..e5f225c92 100644 --- a/src/serializers/type_serializers/tuple.rs +++ b/src/serializers/type_serializers/tuple.rs @@ -2,154 +2,29 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; use std::borrow::Cow; +use std::iter; use serde::ser::SerializeSeq; use crate::definitions::DefinitionsBuilder; +use crate::serializers::type_serializers::any::AnySerializer; use crate::tools::SchemaDict; -use super::any::AnySerializer; use super::{ infer_json_key, 
infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, CombinedSerializer, Extra, PydanticSerializer, SchemaFilter, SerMode, TypeSerializer, }; #[derive(Debug, Clone)] -pub struct TupleVariableSerializer { - item_serializer: Box, +pub struct TupleSerializer { + serializers: Vec, + variadic_item_index: Option, filter: SchemaFilter, name: String, } -impl BuildSerializer for TupleVariableSerializer { - const EXPECTED_TYPE: &'static str = "tuple-variable"; - - fn build( - schema: &PyDict, - config: Option<&PyDict>, - definitions: &mut DefinitionsBuilder, - ) -> PyResult { - let py = schema.py(); - if let Some("positional") = schema.get_as::<&str>(intern!(py, "mode"))? { - return TuplePositionalSerializer::build(schema, config, definitions); - } - let item_serializer = match schema.get_as::<&PyDict>(intern!(py, "items_schema"))? { - Some(items_schema) => CombinedSerializer::build(items_schema, config, definitions)?, - None => AnySerializer::build(schema, config, definitions)?, - }; - let name = format!("tuple[{}, ...]", item_serializer.get_name()); - Ok(Self { - item_serializer: Box::new(item_serializer), - filter: SchemaFilter::from_schema(schema)?, - name, - } - .into()) - } -} - -impl_py_gc_traverse!(TupleVariableSerializer { item_serializer }); - -impl TypeSerializer for TupleVariableSerializer { - fn to_python( - &self, - value: &PyAny, - include: Option<&PyAny>, - exclude: Option<&PyAny>, - extra: &Extra, - ) -> PyResult { - match value.downcast::() { - Ok(py_tuple) => { - let py = value.py(); - let item_serializer = self.item_serializer.as_ref(); - - let mut items = Vec::with_capacity(py_tuple.len()); - for (index, element) in py_tuple.iter().enumerate() { - let op_next = self - .filter - .index_filter(index, include, exclude, Some(py_tuple.len()))?; - if let Some((next_include, next_exclude)) = op_next { - items.push(item_serializer.to_python(element, next_include, next_exclude, extra)?); - } - } - match extra.mode { - SerMode::Json => 
Ok(PyList::new(py, items).into_py(py)), - _ => Ok(PyTuple::new(py, items).into_py(py)), - } - } - Err(_) => { - extra.warnings.on_fallback_py(&self.name, value, extra)?; - infer_to_python(value, include, exclude, extra) - } - } - } - - fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { - match key.downcast::() { - Ok(py_tuple) => { - let item_serializer = self.item_serializer.as_ref(); - - let mut key_builder = KeyBuilder::new(); - for element in py_tuple { - key_builder.push(&item_serializer.json_key(element, extra)?); - } - Ok(Cow::Owned(key_builder.finish())) - } - Err(_) => { - extra.warnings.on_fallback_py(&self.name, key, extra)?; - infer_json_key(key, extra) - } - } - } - - fn serde_serialize( - &self, - value: &PyAny, - serializer: S, - include: Option<&PyAny>, - exclude: Option<&PyAny>, - extra: &Extra, - ) -> Result { - match value.downcast::() { - Ok(py_tuple) => { - let py_tuple: &PyTuple = py_tuple.downcast().map_err(py_err_se_err)?; - let item_serializer = self.item_serializer.as_ref(); - - let mut seq = serializer.serialize_seq(Some(py_tuple.len()))?; - for (index, element) in py_tuple.iter().enumerate() { - let op_next = self - .filter - .index_filter(index, include, exclude, Some(py_tuple.len())) - .map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let item_serialize = - PydanticSerializer::new(element, item_serializer, next_include, next_exclude, extra); - seq.serialize_element(&item_serialize)?; - } - } - seq.end() - } - Err(_) => { - extra.warnings.on_fallback_ser::(&self.name, value, extra)?; - infer_serialize(value, serializer, include, exclude, extra) - } - } - } - - fn get_name(&self) -> &str { - &self.name - } -} - -#[derive(Debug, Clone)] -pub struct TuplePositionalSerializer { - items_serializers: Vec, - extra_serializer: Box, - filter: SchemaFilter, - name: String, -} - -impl BuildSerializer for TuplePositionalSerializer { - const EXPECTED_TYPE: &'static str = "tuple-positional"; +impl 
BuildSerializer for TupleSerializer { + const EXPECTED_TYPE: &'static str = "tuple"; fn build( schema: &PyDict, @@ -158,37 +33,31 @@ impl BuildSerializer for TuplePositionalSerializer { ) -> PyResult { let py = schema.py(); let items: &PyList = schema.get_as_req(intern!(py, "items_schema"))?; - - let extra_serializer = match schema.get_as::<&PyDict>(intern!(py, "extras_schema"))? { - Some(extras_schema) => CombinedSerializer::build(extras_schema, config, definitions)?, - None => AnySerializer::build(schema, config, definitions)?, - }; - let items_serializers: Vec = items + let serializers: Vec = items .iter() .map(|item| CombinedSerializer::build(item.downcast()?, config, definitions)) .collect::>()?; - let descr = items_serializers - .iter() - .map(TypeSerializer::get_name) - .collect::>() - .join(", "); + let mut serializer_names = serializers.iter().map(TypeSerializer::get_name).collect::>(); + let variadic_item_index: Option = schema.get_as(intern!(py, "variadic_item_index"))?; + if let Some(variadic_item_index) = variadic_item_index { + serializer_names.insert(variadic_item_index + 1, "..."); + } + let name = format!("tuple[{}]", serializer_names.join(", ")); + Ok(Self { - items_serializers, - extra_serializer: Box::new(extra_serializer), + serializers, + variadic_item_index, filter: SchemaFilter::from_schema(schema)?, - name: format!("tuple[{descr}]"), + name, } .into()) } } -impl_py_gc_traverse!(TuplePositionalSerializer { - items_serializers, - extra_serializer -}); +impl_py_gc_traverse!(TupleSerializer { serializers }); -impl TypeSerializer for TuplePositionalSerializer { +impl TypeSerializer for TupleSerializer { fn to_python( &self, value: &PyAny, @@ -200,31 +69,53 @@ impl TypeSerializer for TuplePositionalSerializer { Ok(py_tuple) => { let py = value.py(); + let n_items = py_tuple.len(); let mut py_tuple_iter = py_tuple.iter(); - let mut items = Vec::with_capacity(py_tuple.len()); - for (index, serializer) in self.items_serializers.iter().enumerate() { 
- let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, + let mut items = Vec::with_capacity(n_items); + + macro_rules! use_serializers { + ($serializers_iter:expr) => { + for (index, serializer) in $serializers_iter.enumerate() { + let element = match py_tuple_iter.next() { + Some(value) => value, + None => break, + }; + let op_next = self + .filter + .index_filter(index, include, exclude, Some(n_items))?; + if let Some((next_include, next_exclude)) = op_next { + items.push(serializer.to_python(element, next_include, next_exclude, extra)?); + } + } }; - let op_next = self - .filter - .index_filter(index, include, exclude, Some(py_tuple.len()))?; - if let Some((next_include, next_exclude)) = op_next { - items.push(serializer.to_python(element, next_include, next_exclude, extra)?); - } } - let expected_length = self.items_serializers.len(); - let extra_serializer = self.extra_serializer.as_ref(); - for (index2, element) in py_tuple_iter.enumerate() { - let index = index2 + expected_length; - let op_next = self - .filter - .index_filter(index, include, exclude, Some(py_tuple.len()))?; - if let Some((next_include, next_exclude)) = op_next { - items.push(extra_serializer.to_python(element, next_include, next_exclude, extra)?); + + if let Some(variadic_item_index) = self.variadic_item_index { + // Need `saturating_sub` to handle items with too few elements without panicking + let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); + let serializers_iter = self.serializers[..variadic_item_index] + .iter() + .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) + .chain(self.serializers[variadic_item_index + 1..].iter()); + use_serializers!(serializers_iter); + } else { + use_serializers!(self.serializers.iter()); + let mut warned = false; + for (i, element) in py_tuple_iter.enumerate() { + if !warned { + extra + .warnings + .custom_warning("Unexpected extra items present in 
tuple".to_string()); + warned = true; + } + let op_next = + self.filter + .index_filter(i + self.serializers.len(), include, exclude, Some(n_items))?; + if let Some((next_include, next_exclude)) = op_next { + items.push(AnySerializer.to_python(element, next_include, next_exclude, extra)?); + } } - } + }; match extra.mode { SerMode::Json => Ok(PyList::new(py, items).into_py(py)), @@ -244,17 +135,33 @@ impl TypeSerializer for TuplePositionalSerializer { let mut py_tuple_iter = py_tuple.iter(); let mut key_builder = KeyBuilder::new(); - for serializer in &self.items_serializers { - let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, + + let n_items = py_tuple.len(); + + macro_rules! use_serializers { + ($serializers_iter:expr) => { + for serializer in $serializers_iter { + let element = match py_tuple_iter.next() { + Some(value) => value, + None => break, + }; + key_builder.push(&serializer.json_key(element, extra)?); + } }; - key_builder.push(&serializer.json_key(element, extra)?); - } - let extra_serializer = self.extra_serializer.as_ref(); - for element in py_tuple_iter { - key_builder.push(&extra_serializer.json_key(element, extra)?); } + + if let Some(variadic_item_index) = self.variadic_item_index { + // Need `saturating_sub` to handle items with too few elements without panicking + let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); + let serializers_iter = self.serializers[..variadic_item_index] + .iter() + .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) + .chain(self.serializers[variadic_item_index + 1..].iter()); + use_serializers!(serializers_iter); + } else { + use_serializers!(self.serializers.iter()); + }; + Ok(Cow::Owned(key_builder.finish())) } Err(_) => { @@ -276,38 +183,64 @@ impl TypeSerializer for TuplePositionalSerializer { Ok(py_tuple) => { let py_tuple: &PyTuple = py_tuple.downcast().map_err(py_err_se_err)?; + let n_items = py_tuple.len(); let mut 
py_tuple_iter = py_tuple.iter(); - let mut seq = serializer.serialize_seq(Some(py_tuple.len()))?; - for (index, serializer) in self.items_serializers.iter().enumerate() { - let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, + let mut seq = serializer.serialize_seq(Some(n_items))?; + + macro_rules! use_serializers { + ($serializers_iter:expr) => { + for (index, serializer) in $serializers_iter.enumerate() { + let element = match py_tuple_iter.next() { + Some(value) => value, + None => break, + }; + let op_next = self + .filter + .index_filter(index, include, exclude, Some(n_items)) + .map_err(py_err_se_err)?; + if let Some((next_include, next_exclude)) = op_next { + let item_serialize = + PydanticSerializer::new(element, serializer, next_include, next_exclude, extra); + seq.serialize_element(&item_serialize)?; + } + } }; - let op_next = self - .filter - .index_filter(index, include, exclude, Some(py_tuple.len())) - .map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let item_serialize = - PydanticSerializer::new(element, serializer, next_include, next_exclude, extra); - seq.serialize_element(&item_serialize)?; - } } - let expected_length = self.items_serializers.len(); - let extra_serializer = self.extra_serializer.as_ref(); - for (index2, element) in py_tuple_iter.enumerate() { - let index = index2 + expected_length; - let op_next = self - .filter - .index_filter(index, include, exclude, Some(py_tuple.len())) - .map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let item_serialize = - PydanticSerializer::new(element, extra_serializer, next_include, next_exclude, extra); - seq.serialize_element(&item_serialize)?; + if let Some(variadic_item_index) = self.variadic_item_index { + // Need `saturating_sub` to handle items with too few elements without panicking + let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); + let serializers_iter = 
self.serializers[..variadic_item_index] + .iter() + .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) + .chain(self.serializers[variadic_item_index + 1..].iter()); + use_serializers!(serializers_iter); + } else { + use_serializers!(self.serializers.iter()); + let mut warned = false; + for (i, element) in py_tuple_iter.enumerate() { + if !warned { + extra + .warnings + .custom_warning("Unexpected extra items present in tuple".to_string()); + warned = true; + } + let op_next = self + .filter + .index_filter(i + self.serializers.len(), include, exclude, Some(n_items)) + .map_err(py_err_se_err)?; + if let Some((next_include, next_exclude)) = op_next { + let item_serialize = PydanticSerializer::new( + element, + &CombinedSerializer::Any(AnySerializer), + next_include, + next_exclude, + extra, + ); + seq.serialize_element(&item_serialize)?; + } } - } + }; seq.end() } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 7809b1ee3..fc1584080 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -491,8 +491,7 @@ pub fn build_validator<'a>( // decimals decimal::DecimalValidator, // tuples - tuple::TuplePositionalValidator, - tuple::TupleVariableValidator, + tuple::TupleValidator, // list/arrays list::ListValidator, // sets - unique lists @@ -639,8 +638,7 @@ pub enum CombinedValidator { // sets - unique lists Set(set::SetValidator), // tuples - TuplePositional(tuple::TuplePositionalValidator), - TupleVariable(tuple::TupleVariableValidator), + Tuple(tuple::TupleValidator), // dicts/objects (recursive) Dict(dict::DictValidator), // None/null diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index af134c8ee..db9c57953 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -1,39 +1,52 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; +use std::collections::VecDeque; use crate::build_tools::is_strict; -use crate::errors::{ErrorType, ErrorTypeDefaults, 
ValError, ValLineError, ValResult}; +use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericIterable, Input}; use crate::tools::SchemaDict; use crate::validators::Exactness; -use super::list::{get_items_schema, min_length_check}; use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug)] -pub struct TupleVariableValidator { +pub struct TupleValidator { strict: bool, - item_validator: Option>, + validators: Vec, + variadic_item_index: Option, min_length: Option, max_length: Option, name: String, } -impl BuildValidator for TupleVariableValidator { - const EXPECTED_TYPE: &'static str = "tuple-variable"; +impl BuildValidator for TupleValidator { + const EXPECTED_TYPE: &'static str = "tuple"; fn build( schema: &PyDict, config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { let py = schema.py(); - let item_validator = get_items_schema(schema, config, definitions)?.map(Box::new); - let inner_name = item_validator.as_ref().map_or("any", |v| v.get_name()); - let name = format!("tuple[{inner_name}, ...]"); + let items: &PyList = schema.get_as_req(intern!(py, "items_schema"))?; + let validators: Vec = items + .iter() + .map(|item| build_validator(item, config, definitions)) + .collect::>()?; + + let mut validator_names = validators.iter().map(Validator::get_name).collect::>(); + let variadic_item_index: Option = schema.get_as(intern!(py, "variadic_item_index"))?; + // FIXME add friendly schema error if item out of bounds + if let Some(variadic_item_index) = variadic_item_index { + validator_names.insert(variadic_item_index + 1, "..."); + } + let name = format!("tuple[{}]", validator_names.join(", ")); + Ok(Self { strict: is_strict(schema, config)?, - item_validator, + validators, + variadic_item_index, min_length: schema.get_as(intern!(py, "min_length"))?, max_length: schema.get_as(intern!(py, "max_length"))?, 
name, @@ -42,191 +55,276 @@ impl BuildValidator for TupleVariableValidator { } } -impl_py_gc_traverse!(TupleVariableValidator { item_validator }); +impl_py_gc_traverse!(TupleValidator { validators }); -impl Validator for TupleVariableValidator { - fn validate<'data>( +impl TupleValidator { + #[allow(clippy::too_many_arguments)] + fn validate_tuple_items<'s, 'data, I: Input<'data> + 'data>( &self, py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, - ) -> ValResult { - let seq = input.validate_tuple(state.strict_or(self.strict))?; - let exactness = match &seq { - GenericIterable::Tuple(_) | GenericIterable::JsonArray(_) => Exactness::Exact, - GenericIterable::List(_) => Exactness::Strict, - _ => Exactness::Lax, - }; - state.floor_exactness(exactness); + output: &mut Vec, + errors: &mut Vec, + item_validators: &[CombinedValidator], + collection_iter: &mut NextCountingIterator>, + actual_length: Option, + ) -> ValResult<()> { + // Validate the head: + for validator in item_validators { + match collection_iter.next() { + Some((index, input_item)) => match validator.validate(py, input_item, state) { + Ok(item) => self.push_output_item(input, output, item, actual_length)?, + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); + } + Err(ValError::Omit) => (), + Err(err) => return Err(err), + }, + None => { + let index = collection_iter.next_calls() - 1; + if let Some(value) = validator.default_value(py, Some(index), state)? 
{ + output.push(value); + } else { + errors.push(ValLineError::new_with_loc(ErrorTypeDefaults::Missing, input, index)); + } + } + } + } - let output = match self.item_validator { - Some(ref v) => seq.validate_to_vec(py, input, self.max_length, "Tuple", v, state)?, - None => seq.to_vec(py, input, "Tuple", self.max_length)?, - }; - min_length_check!(input, "Tuple", self.min_length, output); - Ok(PyTuple::new(py, &output).into_py(py)) + Ok(()) } - fn get_name(&self) -> &str { - &self.name - } -} + #[allow(clippy::too_many_arguments)] + fn validate_tuple_variable<'data, I: Input<'data> + 'data, InputT: Input<'data> + 'data>( + &self, + py: Python<'data>, + input: &'data InputT, + state: &mut ValidationState, + errors: &mut Vec, + collection_iter: &mut NextCountingIterator>, + actual_length: Option, + ) -> ValResult> { + let expected_length = if self.variadic_item_index.is_some() { + actual_length.unwrap_or(self.validators.len()) + } else { + self.validators.len() + }; + let mut output = Vec::with_capacity(expected_length); + if let Some(variable_validator_index) = self.variadic_item_index { + let (head_validators, [variable_validator, tail_validators @ ..]) = + self.validators.split_at(variable_validator_index) + else { + unreachable!("validators will always contain variable validator") + }; -#[derive(Debug)] -pub struct TuplePositionalValidator { - strict: bool, - items_validators: Vec, - extras_validator: Option>, - name: String, -} + // Validate the "head" items + self.validate_tuple_items( + py, + input, + state, + &mut output, + errors, + head_validators, + collection_iter, + actual_length, + )?; -impl BuildValidator for TuplePositionalValidator { - const EXPECTED_TYPE: &'static str = "tuple-positional"; - fn build( - schema: &PyDict, - config: Option<&PyDict>, - definitions: &mut DefinitionsBuilder, - ) -> PyResult { - let py = schema.py(); - let items: &PyList = schema.get_as_req(intern!(py, "items_schema"))?; - let validators: Vec = items - .iter() - 
.map(|item| build_validator(item, config, definitions)) - .collect::>()?; + let n_tail_validators = tail_validators.len(); + if n_tail_validators == 0 { + for (index, input_item) in collection_iter { + match variable_validator.validate(py, input_item, state) { + Ok(item) => self.push_output_item(input, &mut output, item, actual_length)?, + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); + } + Err(ValError::Omit) => (), + Err(err) => return Err(err), + } + } + } else { + // Populate a buffer with the first n_tail_validators items + // NB: We take from collection_iter.inner to avoid increasing the next calls count + // while populating the buffer. This means the index in the following loop is the + // right one for user errors. + let mut tail_buffer: VecDeque<&'data I> = + collection_iter.inner.by_ref().take(n_tail_validators).collect(); - let descr = validators - .iter() - .map(Validator::get_name) - .collect::>() - .join(", "); - Ok(Self { - strict: is_strict(schema, config)?, - items_validators: validators, - extras_validator: match schema.get_item(intern!(py, "extras_schema"))? 
{ - Some(v) => Some(Box::new(build_validator(v, config, definitions)?)), - None => None, - }, - name: format!("tuple[{descr}]"), - } - .into()) - } -} + // Save the current index for the tail validation below when we recreate a new NextCountingIterator + let mut index = collection_iter.next_calls(); -#[allow(clippy::too_many_arguments)] -fn validate_tuple_positional<'s, 'data, T: Iterator>, I: Input<'data> + 'data>( - py: Python<'data>, - input: &'data impl Input<'data>, - state: &mut ValidationState, - output: &mut Vec, - errors: &mut Vec, - extras_validator: &Option>, - items_validators: &[CombinedValidator], - collection_iter: &mut T, - actual_length: Option, -) -> ValResult<()> { - for (index, validator) in items_validators.iter().enumerate() { - match collection_iter.next() { - Some(result) => match validator.validate(py, result?, state) { - Ok(item) => output.push(item), - Err(ValError::LineErrors(line_errors)) => { - errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); - } - Err(err) => return Err(err), - }, - None => { - if let Some(value) = validator.default_value(py, Some(index), state)? 
{ - output.push(value); - } else { - errors.push(ValLineError::new_with_loc(ErrorTypeDefaults::Missing, input, index)); + // Iterate over all remaining collection items, validating as items "leave" the buffer + for (buffer_item_index, input_item) in collection_iter { + index = buffer_item_index; + // This `unwrap` is safe because you can only get here + // if there were at least `n_tail_validators` (> 0) items in the iterator + let buffered_item = tail_buffer.pop_front().unwrap(); + tail_buffer.push_back(input_item); + + match variable_validator.validate(py, buffered_item, state) { + Ok(item) => self.push_output_item(input, &mut output, item, actual_length)?, + Err(ValError::LineErrors(line_errors)) => { + errors.extend( + line_errors + .into_iter() + .map(|err| err.with_outer_location(buffer_item_index.into())), + ); + } + Err(ValError::Omit) => (), + Err(err) => return Err(err), + } } + + // Validate the buffered items using the tail validators + self.validate_tuple_items( + py, + input, + state, + &mut output, + errors, + tail_validators, + &mut NextCountingIterator::new(tail_buffer.into_iter(), index), + actual_length, + )?; + } + } else { + // Validate all items as positional + self.validate_tuple_items( + py, + input, + state, + &mut output, + errors, + &self.validators, + collection_iter, + actual_length, + )?; + + // Generate an error if there are any extra items: + if collection_iter.next().is_some() { + return Err(ValError::new( + ErrorType::TooLong { + field_type: "Tuple".to_string(), + max_length: self.validators.len(), + actual_length, + context: None, + }, + input, + )); } } + Ok(output) } - for (index, result) in collection_iter.enumerate() { - let item = result?; - match extras_validator { - Some(ref extras_validator) => match extras_validator.validate(py, item, state) { - Ok(item) => output.push(item), - Err(ValError::LineErrors(line_errors)) => { - errors.extend( - line_errors - .into_iter() - .map(|err| err.with_outer_location((index + 
items_validators.len()).into())), - ); - } - Err(ValError::Omit) => (), - Err(err) => return Err(err), - }, - None => { - errors.push(ValLineError::new( + + fn push_output_item<'data>( + &self, + input: &'data impl Input<'data>, + output: &mut Vec, + item: PyObject, + actual_length: Option, + ) -> ValResult<()> { + output.push(item); + if let Some(max_length) = self.max_length { + if output.len() > max_length { + return Err(ValError::new( ErrorType::TooLong { field_type: "Tuple".to_string(), - max_length: items_validators.len(), + max_length, actual_length, context: None, }, input, )); - // no need to continue through further items - break; } } + Ok(()) } - Ok(()) } -impl_py_gc_traverse!(TuplePositionalValidator { - items_validators, - extras_validator -}); - -impl Validator for TuplePositionalValidator { +impl Validator for TupleValidator { fn validate<'data>( &self, py: Python<'data>, input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult { - let collection = input.validate_tuple(state.strict_or(self.strict))?; - let exactness: crate::validators::Exactness = match &collection { + let collection: GenericIterable<'_> = input.validate_tuple(state.strict_or(self.strict))?; + let exactness = match &collection { GenericIterable::Tuple(_) | GenericIterable::JsonArray(_) => Exactness::Exact, GenericIterable::List(_) => Exactness::Strict, _ => Exactness::Lax, }; state.floor_exactness(exactness); - let actual_length = collection.generic_len(); - let expected_length = if self.extras_validator.is_some() { - actual_length.unwrap_or(self.items_validators.len()) - } else { - self.items_validators.len() - }; - let mut output: Vec = Vec::with_capacity(expected_length); let mut errors: Vec = Vec::new(); + let mut iteration_error = None; + macro_rules! 
iter { - ($collection_iter:expr) => {{ - validate_tuple_positional( + ($collection_iter:expr) => { + self.validate_tuple_variable( py, input, state, - &mut output, &mut errors, - &self.extras_validator, - &self.items_validators, - &mut $collection_iter, + &mut NextCountingIterator::new($collection_iter, 0), actual_length, - )? - }}; + ) + }; } - match collection { - GenericIterable::List(collection_iter) => iter!(collection_iter.iter().map(Ok)), - GenericIterable::Tuple(collection_iter) => iter!(collection_iter.iter().map(Ok)), - GenericIterable::JsonArray(collection_iter) => iter!(collection_iter.iter().map(Ok)), - other => iter!(other.as_sequence_iterator(py)?), + let output = match collection { + GenericIterable::List(collection_iter) => iter!(collection_iter.iter())?, + GenericIterable::Tuple(collection_iter) => iter!(collection_iter.iter())?, + GenericIterable::JsonArray(collection_iter) => iter!(collection_iter.iter())?, + other => iter!({ + let mut sequence_iterator = other.as_sequence_iterator(py)?; + let iteration_error = &mut iteration_error; + let mut index: usize = 0; + std::iter::from_fn(move || { + if iteration_error.is_some() { + return None; + } + index += 1; + match sequence_iterator.next() { + Some(Ok(item)) => Some(item), + Some(Err(e)) => { + *iteration_error = Some(ValError::new_with_loc( + ErrorType::IterationError { + error: py_err_string(py, e), + context: None, + }, + input, + index, + )); + None + } + None => None, + } + }) + })?, + }; + + if let Some(err) = iteration_error { + return Err(err); + } + + if let Some(min_length) = self.min_length { + let actual_length = output.len(); + if actual_length < min_length { + errors.push(ValLineError::new( + ErrorType::TooShort { + field_type: "Tuple".to_string(), + min_length, + actual_length, + context: None, + }, + input, + )); + } } + if errors.is_empty() { Ok(PyTuple::new(py, &output).into_py(py)) } else { @@ -238,3 +336,28 @@ impl Validator for TuplePositionalValidator { &self.name } } + 
+struct NextCountingIterator { + inner: I, + count: usize, +} + +impl NextCountingIterator { + fn new(inner: I, count: usize) -> Self { + Self { inner, count } + } + + fn next_calls(&self) -> usize { + self.count + } +} + +impl Iterator for NextCountingIterator { + type Item = (usize, I::Item); + + fn next(&mut self) -> Option { + let count = self.count; + self.count += 1; + self.inner.next().map(|item| (count, item)) + } +} diff --git a/src/validators/union.rs b/src/validators/union.rs index 0148e1287..e010348af 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -457,7 +457,7 @@ impl TaggedUnionValidator { let tag_cow = either_tag.as_cow()?; let tag = tag_cow.as_ref(); // custom logic to distinguish between different function and tuple schemas - if tag == "function" || tag == "tuple" { + if tag == "function" { let mode = match dict { GenericMapping::PyDict(dict) => match dict.get_item(intern!(py, "mode"))? { Some(m) => Some(m.validate_str(true, false)?.into_inner()), @@ -465,21 +465,11 @@ impl TaggedUnionValidator { }, _ => unreachable!(), }; - if tag == "function" { - let mode = mode.ok_or_else(|| self.tag_not_found(input))?; - match mode.as_cow()?.as_ref() { - "plain" => Ok(intern!(py, "function-plain")), - "wrap" => Ok(intern!(py, "function-wrap")), - _ => Ok(intern!(py, "function")), - } - } else { - // tag == "tuple" - if let Some(mode) = mode { - if mode.as_cow()?.as_ref() == "positional" { - return Ok(intern!(py, "tuple-positional")); - } - } - Ok(intern!(py, "tuple-variable")) + let mode = mode.ok_or_else(|| self.tag_not_found(input))?; + match mode.as_cow()?.as_ref() { + "plain" => Ok(intern!(py, "function-plain")), + "wrap" => Ok(intern!(py, "function-wrap")), + _ => Ok(intern!(py, "function")), } } else { Ok(PyString::new(py, tag)) diff --git a/tests/benchmarks/complete_schema.py b/tests/benchmarks/complete_schema.py index 8c24b9b7e..cfc522d29 100644 --- a/tests/benchmarks/complete_schema.py +++ b/tests/benchmarks/complete_schema.py 
@@ -83,16 +83,20 @@ def wrap_function(input_value, validator, info): 'max_length': 42, }, }, - 'field_tuple_var_len_any': {'type': 'model-field', 'schema': {'type': 'tuple-variable'}}, + 'field_tuple_var_len_any': { + 'type': 'model-field', + 'schema': {'type': 'tuple', 'items_schema': [{'type': 'any'}], 'variadic_item_index': 0}, + }, 'field_tuple_var_len_float': { 'type': 'model-field', - 'schema': {'type': 'tuple-variable', 'items_schema': {'type': 'float'}}, + 'schema': {'type': 'tuple', 'items_schema': [{'type': 'float'}], 'variadic_item_index': 0}, }, 'field_tuple_var_len_float_con': { 'type': 'model-field', 'schema': { - 'type': 'tuple-variable', - 'items_schema': {'type': 'float'}, + 'type': 'tuple', + 'items_schema': [{'type': 'float'}], + 'variadic_item_index': 0, 'min_length': 3, 'max_length': 42, }, @@ -100,7 +104,7 @@ def wrap_function(input_value, validator, info): 'field_tuple_fix_len': { 'type': 'model-field', 'schema': { - 'type': 'tuple-positional', + 'type': 'tuple', 'items_schema': [{'type': 'str'}, {'type': 'int'}, {'type': 'float'}, {'type': 'bool'}], }, }, diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index cb08436c6..f1ec32eef 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -848,7 +848,7 @@ def t(): def test_positional_tuple(benchmark): v = SchemaValidator( { - 'type': 'tuple-positional', + 'type': 'tuple', 'items_schema': [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}, {'type': 'int'}, {'type': 'int'}], } ) @@ -859,7 +859,7 @@ def test_positional_tuple(benchmark): @pytest.mark.benchmark(group='tuple') def test_variable_tuple(benchmark): - v = SchemaValidator({'type': 'tuple-variable', 'items_schema': {'type': 'int'}}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}) assert v.validate_python((1, 2, 3, '4', 5)) == (1, 2, 3, 4, 5) benchmark(v.validate_python, (1, 2, 3, '4', 5)) 
@@ -867,7 +867,7 @@ def test_variable_tuple(benchmark): @pytest.mark.benchmark(group='tuple-many') def test_tuple_many_variable(benchmark): - v = SchemaValidator({'type': 'tuple-variable', 'items_schema': {'type': 'int'}}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}) assert v.validate_python(list(range(10))) == tuple(range(10)) benchmark(v.validate_python, list(range(10))) @@ -875,7 +875,7 @@ def test_tuple_many_variable(benchmark): @pytest.mark.benchmark(group='tuple-many') def test_tuple_many_positional(benchmark): - v = SchemaValidator({'type': 'tuple-positional', 'items_schema': [], 'extras_schema': {'type': 'int'}}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}) assert v.validate_python(list(range(10))) == tuple(range(10)) benchmark(v.validate_python, list(range(10))) diff --git a/tests/serializers/test_list_tuple.py b/tests/serializers/test_list_tuple.py index c695bbcc4..df6aabd8a 100644 --- a/tests/serializers/test_list_tuple.py +++ b/tests/serializers/test_list_tuple.py @@ -329,19 +329,20 @@ def test_filter_list_of_dicts(): def test_positional_tuple(): - s = SchemaSerializer( - {'type': 'tuple-positional', 'items_schema': [{'type': 'int'}, {'type': 'bytes'}, {'type': 'float'}]} - ) + s = SchemaSerializer({'type': 'tuple', 'items_schema': [{'type': 'int'}, {'type': 'bytes'}, {'type': 'float'}]}) assert s.to_python((1, b'2', 3.0)) == (1, b'2', 3.0) - assert s.to_python((1, b'2', 3.0, 123)) == (1, b'2', 3.0, 123) + with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'): + assert s.to_python((1, b'2', 3.0, 123)) == (1, b'2', 3.0, 123) assert s.to_python((1, b'2')) == (1, b'2') assert s.to_python((1, b'2', 3.0), mode='json') == [1, '2', 3.0] - assert s.to_python((1, b'2', 3.0, 123), mode='json') == [1, '2', 3.0, 123] + with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'): + assert s.to_python((1, 
b'2', 3.0, 123), mode='json') == [1, '2', 3.0, 123] assert s.to_python((1, b'2'), mode='json') == [1, '2'] assert s.to_json((1, b'2', 3.0)) == b'[1,"2",3.0]' - assert s.to_json((1, b'2', 3.0, 123)) == b'[1,"2",3.0,123]' + with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'): + assert s.to_json((1, b'2', 3.0, 123)) == b'[1,"2",3.0,123]' assert s.to_json((1, b'2')) == b'[1,"2"]' @@ -351,7 +352,7 @@ def f(prefix, value, _info): s = SchemaSerializer( { - 'type': 'tuple-positional', + 'type': 'tuple', 'items_schema': [ core_schema.any_schema( serialization=core_schema.plain_serializer_function_ser_schema(partial(f, 'a'), info_arg=True) @@ -359,10 +360,11 @@ def f(prefix, value, _info): core_schema.any_schema( serialization=core_schema.plain_serializer_function_ser_schema(partial(f, 'b'), info_arg=True) ), + core_schema.any_schema( + serialization=core_schema.plain_serializer_function_ser_schema(partial(f, 'extra'), info_arg=True) + ), ], - 'extras_schema': core_schema.any_schema( - serialization=core_schema.plain_serializer_function_ser_schema(partial(f, 'extra'), info_arg=True) - ), + 'variadic_item_index': 2, } ) assert s.to_python((1,)) == ('a1',) diff --git a/tests/serializers/test_none.py b/tests/serializers/test_none.py index 6bbd49c46..e5e1a9093 100644 --- a/tests/serializers/test_none.py +++ b/tests/serializers/test_none.py @@ -16,7 +16,7 @@ 'url', 'multi-host-url', ) -all_types = all_scalars + ('list', 'tuple-variable', 'dict', 'set', 'frozenset') +all_types = all_scalars + ('list', 'dict', 'set', 'frozenset') @pytest.mark.parametrize('schema_type', all_types) diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index d4b53cbe4..2a6032174 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -79,17 +79,7 @@ def args(*args, **kwargs): (core_schema.callable_schema, args(), {'type': 'callable'}), (core_schema.list_schema, args(), {'type': 'list'}), (core_schema.list_schema, 
args({'type': 'int'}), {'type': 'list', 'items_schema': {'type': 'int'}}), - ( - core_schema.tuple_positional_schema, - args([{'type': 'int'}]), - {'type': 'tuple-positional', 'items_schema': [{'type': 'int'}]}, - ), - (core_schema.tuple_positional_schema, args([]), {'type': 'tuple-positional', 'items_schema': []}), - ( - core_schema.tuple_variable_schema, - args({'type': 'int'}), - {'type': 'tuple-variable', 'items_schema': {'type': 'int'}}, - ), + (core_schema.tuple_schema, args([]), {'type': 'tuple', 'items_schema': []}), ( core_schema.set_schema, args({'type': 'int'}, min_length=4), diff --git a/tests/test_typing.py b/tests/test_typing.py index dcd2f267a..fb9f3e949 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -77,9 +77,9 @@ def test_schema_typing() -> None: SchemaValidator(schema) schema: CoreSchema = {'type': 'set', 'items_schema': {'type': 'str'}, 'max_length': 3} SchemaValidator(schema) - schema: CoreSchema = {'type': 'tuple-variable', 'items_schema': {'type': 'str'}, 'max_length': 3} + schema: CoreSchema = {'type': 'tuple', 'items_schema': [{'type': 'str'}], 'variadic_item_index': 0, 'max_length': 3} SchemaValidator(schema) - schema: CoreSchema = {'type': 'tuple-positional', 'items_schema': [{'type': 'str'}, {'type': 'int'}]} + schema: CoreSchema = {'type': 'tuple', 'items_schema': [{'type': 'str'}, {'type': 'int'}]} SchemaValidator(schema) schema: CoreSchema = {'type': 'frozenset', 'items_schema': {'type': 'str'}, 'max_length': 3} SchemaValidator(schema) diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index 9f7a93d57..1d7994054 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -279,7 +279,7 @@ def test_outside_parent(): } ), [ - core_schema.tuple_positional_schema( + core_schema.tuple_schema( [core_schema.int_schema(), core_schema.int_schema(), core_schema.str_schema()], ref='tuple-iis' ) ], @@ -426,7 +426,7 @@ 
def multiple_tuple_schema() -> SchemaValidator: } ), [ - core_schema.tuple_positional_schema( + core_schema.tuple_schema( [ core_schema.int_schema(), core_schema.nullable_schema(core_schema.definition_reference_schema('t')), @@ -522,7 +522,7 @@ def wrap_func(input_value, validator, info): [ core_schema.with_info_wrap_validator_function( wrap_func, - core_schema.tuple_positional_schema( + core_schema.tuple_schema( [ core_schema.int_schema(), core_schema.nullable_schema(core_schema.definition_reference_schema('wrapper')), diff --git a/tests/validators/test_frozenset.py b/tests/validators/test_frozenset.py index 9ec2f10d5..6ee297b22 100644 --- a/tests/validators/test_frozenset.py +++ b/tests/validators/test_frozenset.py @@ -276,13 +276,13 @@ def gen(error: bool): [ pytest.param( {1: 10, 2: 20, '3': '30'}.items(), - {'type': 'tuple-variable', 'items_schema': {'type': 'any'}}, + {'type': 'tuple', 'items_schema': [{'type': 'any'}], 'variadic_item_index': 0}, frozenset(((1, 10), (2, 20), ('3', '30'))), id='Tuple[Any, Any]', ), pytest.param( {1: 10, 2: 20, '3': '30'}.items(), - {'type': 'tuple-variable', 'items_schema': {'type': 'int'}}, + {'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}, frozenset(((1, 10), (2, 20), (3, 30))), id='Tuple[int, int]', ), diff --git a/tests/validators/test_list.py b/tests/validators/test_list.py index 6f5ca8a34..187c0b851 100644 --- a/tests/validators/test_list.py +++ b/tests/validators/test_list.py @@ -305,13 +305,13 @@ def gen(error: bool): [ pytest.param( {1: 10, 2: 20, '3': '30'}.items(), - {'type': 'tuple-variable', 'items_schema': {'type': 'any'}}, + {'type': 'tuple', 'items_schema': [{'type': 'any'}], 'variadic_item_index': 0}, [(1, 10), (2, 20), ('3', '30')], id='Tuple[Any, Any]', ), pytest.param( {1: 10, 2: 20, '3': '30'}.items(), - {'type': 'tuple-variable', 'items_schema': {'type': 'int'}}, + {'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}, [(1, 10), (2, 20), (3, 30)], 
id='Tuple[int, int]', ), diff --git a/tests/validators/test_set.py b/tests/validators/test_set.py index a6babb80d..e58cedff8 100644 --- a/tests/validators/test_set.py +++ b/tests/validators/test_set.py @@ -242,13 +242,13 @@ def gen(error: bool): [ pytest.param( {1: 10, 2: 20, '3': '30'}.items(), - {'type': 'tuple-variable', 'items_schema': {'type': 'any'}}, + {'type': 'tuple', 'items_schema': [{'type': 'any'}], 'variadic_item_index': 0}, {(1, 10), (2, 20), ('3', '30')}, id='Tuple[Any, Any]', ), pytest.param( {1: 10, 2: 20, '3': '30'}.items(), - {'type': 'tuple-variable', 'items_schema': {'type': 'int'}}, + {'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}, {(1, 10), (2, 20), (3, 30)}, id='Tuple[int, int]', ), diff --git a/tests/validators/test_tuple.py b/tests/validators/test_tuple.py index 023a2b48e..e62f49f1d 100644 --- a/tests/validators/test_tuple.py +++ b/tests/validators/test_tuple.py @@ -5,19 +5,19 @@ import pytest from dirty_equals import IsNonNegative, IsTuple -from pydantic_core import SchemaValidator, ValidationError +from pydantic_core import SchemaValidator, ValidationError, core_schema from ..conftest import Err, PyAndJson, infinite_generator @pytest.mark.parametrize( - 'mode,items,input_value,expected', + 'variadic_item_index,items,input_value,expected', [ - ('variable', {'type': 'int'}, [1, 2, 3], (1, 2, 3)), - ('variable', {'type': 'int'}, 1, Err('[type=tuple_type, input_value=1, input_type=int]')), - ('positional', [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}], [1, 2, '3'], (1, 2, 3)), + (0, [{'type': 'int'}], [1, 2, 3], (1, 2, 3)), + (0, [{'type': 'int'}], 1, Err('[type=tuple_type, input_value=1, input_type=int]')), + (None, [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}], [1, 2, '3'], (1, 2, 3)), ( - 'positional', + None, [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}], 5, Err('[type=tuple_type, input_value=5, input_type=int]'), @@ -25,8 +25,8 @@ ], ids=repr, ) -def test_tuple_json(py_and_json: 
PyAndJson, mode, items, input_value, expected): - v = py_and_json({'type': f'tuple-{mode}', 'items_schema': items}) +def test_tuple_json(py_and_json: PyAndJson, variadic_item_index, items, input_value, expected): + v = py_and_json(core_schema.tuple_schema(items_schema=items, variadic_item_index=variadic_item_index)) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_test(input_value) @@ -35,7 +35,7 @@ def test_tuple_json(py_and_json: PyAndJson, mode, items, input_value, expected): def test_any_no_copy(): - v = SchemaValidator({'type': 'tuple-variable'}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'any'}], 'variadic_item_index': 0}) input_value = (1, '2', b'3') output = v.validate_python(input_value) assert output == input_value @@ -44,20 +44,22 @@ def test_any_no_copy(): @pytest.mark.parametrize( - 'mode,items,input_value,expected', + 'variadic_item_index,items,input_value,expected', [ - ('variable', {'type': 'int'}, (1, 2, '33'), (1, 2, 33)), - ('variable', {'type': 'str'}, (b'1', b'2', '33'), ('1', '2', '33')), - ('positional', [{'type': 'int'}, {'type': 'str'}, {'type': 'float'}], (1, b'a', 33), (1, 'a', 33.0)), + (0, [{'type': 'int'}], (1, 2, '33'), (1, 2, 33)), + (0, [{'type': 'str'}], (b'1', b'2', '33'), ('1', '2', '33')), + (None, [{'type': 'int'}, {'type': 'str'}, {'type': 'float'}], (1, b'a', 33), (1, 'a', 33.0)), ], ) -def test_tuple_strict_passes_with_tuple(mode, items, input_value, expected): - v = SchemaValidator({'type': f'tuple-{mode}', 'items_schema': items, 'strict': True}) +def test_tuple_strict_passes_with_tuple(variadic_item_index, items, input_value, expected): + v = SchemaValidator( + core_schema.tuple_schema(items_schema=items, variadic_item_index=variadic_item_index, strict=True) + ) assert v.validate_python(input_value) == expected def test_empty_positional_tuple(): - v = SchemaValidator({'type': 'tuple-positional', 'items_schema': []}) + v = 
SchemaValidator({'type': 'tuple', 'items_schema': []}) assert v.validate_python(()) == () assert v.validate_python([]) == () with pytest.raises(ValidationError) as exc_info: @@ -76,11 +78,13 @@ def test_empty_positional_tuple(): @pytest.mark.parametrize( - 'mode,items', [('variable', {'type': 'int'}), ('positional', [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}])] + 'variadic_item_index,items', [(0, [{'type': 'int'}]), (None, [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}])] ) @pytest.mark.parametrize('wrong_coll_type', [list, set, frozenset]) -def test_tuple_strict_fails_without_tuple(wrong_coll_type: Type[Any], mode, items): - v = SchemaValidator({'type': f'tuple-{mode}', 'items_schema': items, 'strict': True}) +def test_tuple_strict_fails_without_tuple(wrong_coll_type: Type[Any], variadic_item_index, items): + v = SchemaValidator( + core_schema.tuple_schema(variadic_item_index=variadic_item_index, items_schema=items, strict=True) + ) with pytest.raises(ValidationError) as exc_info: v.validate_python(wrong_coll_type([1, 2, '33'])) assert exc_info.value.errors(include_url=False) == [ @@ -119,7 +123,7 @@ def test_tuple_strict_fails_without_tuple(wrong_coll_type: Type[Any], mode, item ids=repr, ) def test_tuple_var_len_kwargs(kwargs: Dict[str, Any], input_value, expected): - v = SchemaValidator({'type': 'tuple-variable', **kwargs}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'any'}], 'variadic_item_index': 0, **kwargs}) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_python(input_value) @@ -128,7 +132,7 @@ def test_tuple_var_len_kwargs(kwargs: Dict[str, Any], input_value, expected): @pytest.mark.parametrize( - 'mode,items', [('variable', {'type': 'int'}), ('positional', [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}])] + 'variadic_item_index,items', [(0, [{'type': 'int'}]), (None, [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}])] ) @pytest.mark.parametrize( 
'input_value,expected', @@ -144,8 +148,8 @@ def test_tuple_var_len_kwargs(kwargs: Dict[str, Any], input_value, expected): ], ids=repr, ) -def test_tuple_validate(input_value, expected, mode, items): - v = SchemaValidator({'type': f'tuple-{mode}', 'items_schema': items}) +def test_tuple_validate(input_value, expected, variadic_item_index, items): + v = SchemaValidator(core_schema.tuple_schema(items_schema=items, variadic_item_index=variadic_item_index)) if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_python(input_value) @@ -157,10 +161,10 @@ def test_tuple_validate(input_value, expected, mode, items): # on the first test run. This is a workaround to make sure the generator is # always recreated. @pytest.mark.parametrize( - 'mode,items', [('variable', {'type': 'int'}), ('positional', [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}])] + 'variadic_item_index,items', [(0, [{'type': 'int'}]), (None, [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}])] ) -def test_tuple_validate_iterator(mode, items): - v = SchemaValidator({'type': f'tuple-{mode}', 'items_schema': items}) +def test_tuple_validate_iterator(variadic_item_index, items): + v = SchemaValidator(core_schema.tuple_schema(items_schema=items, variadic_item_index=variadic_item_index)) assert v.validate_python((x for x in [1, 2, '3'])) == (1, 2, 3) @@ -175,7 +179,7 @@ def test_tuple_validate_iterator(mode, items): ], ) def test_tuple_var_len_errors(input_value, index): - v = SchemaValidator({'type': 'tuple-variable', 'items_schema': {'type': 'int'}}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}) with pytest.raises(ValidationError) as exc_info: assert v.validate_python(input_value) assert exc_info.value.errors(include_url=False) == [ @@ -203,7 +207,7 @@ def test_tuple_var_len_errors(input_value, index): ], ) def test_tuple_fix_len_errors(input_value, items, index): - v = SchemaValidator({'type': 
'tuple-positional', 'items_schema': items}) + v = SchemaValidator({'type': 'tuple', 'items_schema': items}) with pytest.raises(ValidationError) as exc_info: assert v.validate_python(input_value) assert exc_info.value.errors(include_url=False) == [ @@ -218,10 +222,7 @@ def test_tuple_fix_len_errors(input_value, items, index): def test_multiple_missing(py_and_json: PyAndJson): v = py_and_json( - { - 'type': 'tuple-positional', - 'items_schema': [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}, {'type': 'int'}], - } + {'type': 'tuple', 'items_schema': [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}, {'type': 'int'}]} ) assert v.validate_test([1, 2, 3, 4]) == (1, 2, 3, 4) with pytest.raises(ValidationError) as exc_info: @@ -239,7 +240,7 @@ def test_multiple_missing(py_and_json: PyAndJson): def test_extra_arguments(py_and_json: PyAndJson): - v = py_and_json({'type': 'tuple-positional', 'items_schema': [{'type': 'int'}, {'type': 'int'}]}) + v = py_and_json({'type': 'tuple', 'items_schema': [{'type': 'int'}, {'type': 'int'}]}) assert v.validate_test([1, 2]) == (1, 2) with pytest.raises(ValidationError) as exc_info: v.validate_test([1, 2, 3, 4]) @@ -256,7 +257,7 @@ def test_extra_arguments(py_and_json: PyAndJson): def test_positional_empty(py_and_json: PyAndJson): - v = py_and_json({'type': 'tuple-positional', 'items_schema': []}) + v = py_and_json({'type': 'tuple', 'items_schema': []}) assert v.validate_test([]) == () assert v.validate_python(()) == () with pytest.raises(ValidationError, match='type=too_long,'): @@ -264,7 +265,7 @@ def test_positional_empty(py_and_json: PyAndJson): def test_positional_empty_extra(py_and_json: PyAndJson): - v = py_and_json({'type': 'tuple-positional', 'items_schema': [], 'extras_schema': {'type': 'int'}}) + v = py_and_json({'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}) assert v.validate_test([]) == () assert v.validate_python(()) == () assert v.validate_test([1]) == (1,) @@ -273,7 +274,15 @@ def 
test_positional_empty_extra(py_and_json: PyAndJson): @pytest.mark.parametrize('input_value,expected', [((1, 2, 3), (1, 2, 3)), ([1, 2, 3], [1, 2, 3])]) def test_union_tuple_list(input_value, expected): - v = SchemaValidator({'type': 'union', 'choices': [{'type': 'tuple-variable'}, {'type': 'list'}]}) + v = SchemaValidator( + { + 'type': 'union', + 'choices': [ + {'type': 'tuple', 'items_schema': [{'type': 'any'}], 'variadic_item_index': 0}, + {'type': 'list'}, + ], + } + ) assert v.validate_python(input_value) == expected @@ -313,8 +322,8 @@ def test_union_tuple_var_len(input_value, expected): { 'type': 'union', 'choices': [ - {'type': 'tuple-variable', 'items_schema': {'type': 'int'}, 'strict': True}, - {'type': 'tuple-variable', 'items_schema': {'type': 'str'}, 'strict': True}, + {'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0, 'strict': True}, + {'type': 'tuple', 'items_schema': [{'type': 'str'}], 'variadic_item_index': 0, 'strict': True}, ], } ) @@ -360,16 +369,8 @@ def test_union_tuple_fix_len(input_value, expected): { 'type': 'union', 'choices': [ - { - 'type': 'tuple-positional', - 'items_schema': [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}], - 'strict': True, - }, - { - 'type': 'tuple-positional', - 'items_schema': [{'type': 'str'}, {'type': 'str'}, {'type': 'str'}], - 'strict': True, - }, + {'type': 'tuple', 'items_schema': [{'type': 'int'}, {'type': 'int'}, {'type': 'int'}], 'strict': True}, + {'type': 'tuple', 'items_schema': [{'type': 'str'}, {'type': 'str'}, {'type': 'str'}], 'strict': True}, ], } ) @@ -383,7 +384,7 @@ def test_union_tuple_fix_len(input_value, expected): def test_tuple_fix_error(): - v = SchemaValidator({'type': 'tuple-positional', 'items_schema': [{'type': 'int'}, {'type': 'str'}]}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'int'}, {'type': 'str'}]}) with pytest.raises(ValidationError) as exc_info: v.validate_python([1]) @@ -403,13 +404,9 @@ def test_tuple_fix_error(): ([1], 
Err('type=missing', errors=[{'type': 'missing', 'loc': (1,), 'msg': 'Field required', 'input': [1]}])), ], ) -def test_tuple_fix_extra(input_value, expected, cache): +def test_tuple_fix_extra(input_value, expected): v = SchemaValidator( - { - 'type': 'tuple-positional', - 'items_schema': [{'type': 'int'}, {'type': 'str'}], - 'extras_schema': {'type': 'str'}, - } + {'type': 'tuple', 'items_schema': [{'type': 'int'}, {'type': 'str'}, {'type': 'str'}], 'variadic_item_index': 2} ) if isinstance(expected, Err): @@ -421,9 +418,7 @@ def test_tuple_fix_extra(input_value, expected, cache): def test_tuple_fix_extra_any(): - v = SchemaValidator( - {'type': 'tuple-positional', 'items_schema': [{'type': 'str'}], 'extras_schema': {'type': 'any'}} - ) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'str'}, {'type': 'any'}], 'variadic_item_index': 1}) assert v.validate_python([b'1']) == ('1',) assert v.validate_python([b'1', 2]) == ('1', 2) assert v.validate_python((b'1', 2)) == ('1', 2) @@ -443,7 +438,7 @@ def gen(error: bool): raise RuntimeError('error') yield 3 - v = SchemaValidator({'type': 'tuple-variable', 'items_schema': {'type': 'int'}}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}) assert v.validate_python(gen(False)) == (1, 2, 3) msg = r'Error iterating over object, error: RuntimeError: error \[type=iteration_error,' @@ -456,13 +451,13 @@ def gen(error: bool): [ pytest.param( {1: 10, 2: 20, '3': '30'}.items(), - {'type': 'tuple-variable', 'items_schema': {'type': 'any'}}, + {'type': 'tuple', 'items_schema': [{'type': 'any'}], 'variadic_item_index': 0}, ((1, 10), (2, 20), ('3', '30')), id='Tuple[Any, Any]', ), pytest.param( {1: 10, 2: 20, '3': '30'}.items(), - {'type': 'tuple-variable', 'items_schema': {'type': 'int'}}, + {'type': 'tuple', 'items_schema': [{'type': 'int'}], 'variadic_item_index': 0}, ((1, 10), (2, 20), (3, 30)), id='Tuple[int, int]', ), @@ -470,7 +465,7 @@ def gen(error: bool): ], ) 
def test_frozenset_from_dict_items(input_value, items_schema, expected): - v = SchemaValidator({'type': 'tuple-variable', 'items_schema': items_schema}) + v = SchemaValidator({'type': 'tuple', 'items_schema': [items_schema], 'variadic_item_index': 0}) output = v.validate_python(input_value) assert isinstance(output, tuple) assert output == expected @@ -487,8 +482,9 @@ def test_frozenset_from_dict_items(input_value, items_schema, expected): def test_length_constraints_omit(input_value, expected): v = SchemaValidator( { - 'type': 'tuple-variable', - 'items_schema': {'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'omit'}, + 'type': 'tuple', + 'items_schema': [{'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'omit'}], + 'variadic_item_index': 0, 'max_length': 4, } ) diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 7ca0d9f54..189a555bb 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -164,7 +164,11 @@ def test_dict_keys(): def test_tuple_variable(py_and_json: PyAndJson): v = py_and_json( - {'type': 'tuple-variable', 'items_schema': {'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'omit'}} + { + 'type': 'tuple', + 'items_schema': [{'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'omit'}], + 'variadic_item_index': 0, + } ) assert v.validate_python((1, 2, 3)) == (1, 2, 3) assert v.validate_python([1, '2', 3]) == (1, 2, 3) @@ -174,7 +178,7 @@ def test_tuple_variable(py_and_json: PyAndJson): def test_tuple_positional(): v = SchemaValidator( { - 'type': 'tuple-positional', + 'type': 'tuple', 'items_schema': [{'type': 'int'}, {'type': 'default', 'schema': {'type': 'int'}, 'default': 42}], } ) @@ -187,9 +191,13 @@ def test_tuple_positional(): def test_tuple_positional_omit(): v = SchemaValidator( { - 'type': 'tuple-positional', - 'items_schema': [{'type': 'int'}, {'type': 'int'}], - 'extras_schema': {'type': 'default', 'schema': {'type': 
'int'}, 'on_error': 'omit'}, + 'type': 'tuple', + 'items_schema': [ + {'type': 'int'}, + {'type': 'int'}, + {'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'omit'}, + ], + 'variadic_item_index': 2, } ) assert v.validate_python((1, '2')) == (1, 2) From 8dde89e06523f89a803d7753d24096839f6039bc Mon Sep 17 00:00:00 2001 From: Sukhorosov Aleksey <39103632+alexdrydew@users.noreply.github.com> Date: Wed, 10 Jan 2024 12:15:44 +0100 Subject: [PATCH 173/550] Use stricter serializer for unions of simple types (#1132) --- src/serializers/type_serializers/simple.rs | 28 +++-- tests/serializers/test_union.py | 116 +++++++++++++++++++++ 2 files changed, 134 insertions(+), 10 deletions(-) diff --git a/src/serializers/type_serializers/simple.rs b/src/serializers/type_serializers/simple.rs index dafb2b786..65fbee146 100644 --- a/src/serializers/type_serializers/simple.rs +++ b/src/serializers/type_serializers/simple.rs @@ -5,11 +5,12 @@ use std::borrow::Cow; use serde::Serialize; +use crate::PydanticSerializationUnexpectedValue; use crate::{definitions::DefinitionsBuilder, input::Int}; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, - SerMode, TypeSerializer, + SerCheck, SerMode, TypeSerializer, }; #[derive(Debug, Clone)] @@ -85,7 +86,7 @@ impl TypeSerializer for NoneSerializer { } macro_rules! build_simple_serializer { - ($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident) => { + ($struct_name:ident, $expected_type:literal, $rust_type:ty, $ob_type:expr, $key_method:ident, $subtypes_allowed:expr) => { #[derive(Debug, Clone)] pub struct $struct_name; @@ -114,12 +115,15 @@ macro_rules! 
build_simple_serializer { let py = value.py(); match extra.ob_type_lookup.is_type(value, $ob_type) { IsType::Exact => Ok(value.into_py(py)), - IsType::Subclass => match extra.mode { - SerMode::Json => { - let rust_value = value.extract::<$rust_type>()?; - Ok(rust_value.to_object(py)) - } - _ => infer_to_python(value, include, exclude, extra), + IsType::Subclass => match extra.check { + SerCheck::Strict => Err(PydanticSerializationUnexpectedValue::new_err(None)), + SerCheck::Lax | SerCheck::None => match extra.mode { + SerMode::Json => { + let rust_value = value.extract::<$rust_type>()?; + Ok(rust_value.to_object(py)) + } + _ => infer_to_python(value, include, exclude, extra), + }, }, IsType::False => { extra.warnings.on_fallback_py(self.get_name(), value, extra)?; @@ -160,6 +164,10 @@ macro_rules! build_simple_serializer { fn get_name(&self) -> &str { Self::EXPECTED_TYPE } + + fn retry_with_lax_check(&self) -> bool { + $subtypes_allowed + } } }; } @@ -168,7 +176,7 @@ pub(crate) fn to_str_json_key(key: &PyAny) -> PyResult> { Ok(key.str()?.to_string_lossy()) } -build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key); +build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key, true); pub(crate) fn bool_json_key(key: &PyAny) -> PyResult> { let v = if key.is_true().unwrap_or(false) { @@ -179,4 +187,4 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult> { Ok(Cow::Borrowed(v)) } -build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key); +build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key, false); diff --git a/tests/serializers/test_union.py b/tests/serializers/test_union.py index 9b021e66e..ee5ed3fc4 100644 --- a/tests/serializers/test_union.py +++ b/tests/serializers/test_union.py @@ -1,6 +1,8 @@ import dataclasses import json import re +import uuid +from decimal import Decimal from typing import Any, ClassVar, Union import pytest @@ -510,3 +512,117 @@ class 
Item(BaseModel): ) assert s.to_python([DBUser(name='John', password='secret')]) == [{'name': 'John'}] + + +EXAMPLE_UUID = uuid.uuid4() + + +class IntSubclass(int): + pass + + +@pytest.mark.parametrize('reverse', [False, True]) +@pytest.mark.parametrize( + 'core_schema_left,core_schema_right,input_value,expected_value', + [ + (core_schema.int_schema(), core_schema.bool_schema(), True, True), + (core_schema.int_schema(), core_schema.bool_schema(), 1, 1), + (core_schema.str_schema(), core_schema.int_schema(), 1, 1), + (core_schema.str_schema(), core_schema.int_schema(), '1', '1'), + (core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1), + ( + core_schema.decimal_schema(), + core_schema.int_schema(), + Decimal('1'), + Decimal('1'), + ), + (core_schema.decimal_schema(), core_schema.int_schema(), 1, 1), + ( + core_schema.decimal_schema(), + core_schema.float_schema(), + Decimal('1.'), + Decimal('1.'), + ), + ( + core_schema.decimal_schema(), + core_schema.str_schema(), + Decimal('_1'), + Decimal('_1'), + ), + ( + core_schema.decimal_schema(), + core_schema.str_schema(), + '_1', + '_1', + ), + ( + core_schema.uuid_schema(), + core_schema.str_schema(), + EXAMPLE_UUID, + EXAMPLE_UUID, + ), + ( + core_schema.uuid_schema(), + core_schema.str_schema(), + str(EXAMPLE_UUID), + str(EXAMPLE_UUID), + ), + ], +) +def test_union_serializer_picks_exact_type_over_subclass( + core_schema_left, core_schema_right, input_value, expected_value, reverse +): + s = SchemaSerializer( + core_schema.union_schema( + [core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right] + ) + ) + assert s.to_python(input_value) == expected_value + + +@pytest.mark.parametrize('reverse', [False, True]) +@pytest.mark.parametrize( + 'core_schema_left,core_schema_right,input_value,expected_value', + [ + (core_schema.int_schema(), core_schema.bool_schema(), True, True), + (core_schema.int_schema(), core_schema.bool_schema(), 1, 1), + (core_schema.str_schema(), 
core_schema.int_schema(), 1, 1), + (core_schema.str_schema(), core_schema.int_schema(), '1', '1'), + (core_schema.int_schema(), core_schema.bool_schema(), IntSubclass(1), 1), + ( + core_schema.decimal_schema(), + core_schema.int_schema(), + Decimal('1'), + '1', + ), + (core_schema.decimal_schema(), core_schema.int_schema(), 1, 1), + ( + core_schema.decimal_schema(), + core_schema.float_schema(), + Decimal('1.'), + '1', + ), + ( + core_schema.decimal_schema(), + core_schema.str_schema(), + Decimal('_1'), + '1', + ), + ( + core_schema.decimal_schema(), + core_schema.str_schema(), + '_1', + '_1', + ), + ], +) +def test_union_serializer_picks_exact_type_over_subclass_json( + core_schema_left, core_schema_right, input_value, expected_value, reverse +): + s = SchemaSerializer( + core_schema.union_schema( + [core_schema_right, core_schema_left] if reverse else [core_schema_left, core_schema_right] + ) + ) + assert s.to_python(input_value, mode='json') == expected_value + assert s.to_json(input_value) == json.dumps(expected_value).encode() From 82c4ca35f6281bd49117e86e4371ee78dcc8c65c Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 10 Jan 2024 13:46:15 +0000 Subject: [PATCH 174/550] adjust build for upload-artifact-4 (#1146) --- .github/workflows/ci.yml | 52 ++++++++++------------------------------ 1 file changed, 13 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1bf0698a1..826635499 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -358,7 +358,7 @@ jobs: rust-toolchain: stable - uses: actions/upload-artifact@v4 with: - name: pypi_files + name: pypi_files_sdist path: dist build: @@ -447,7 +447,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - manylinux: ${{ matrix.manylinux == 'manylinux' && 'auto' || matrix.manylinux }} + manylinux: ${{ matrix.manylinux }} args: --release --out dist --interpreter ${{ matrix.interpreter || '3.8 3.9 3.10 3.11 3.12 pypy3.9 pypy3.10' 
}} rust-toolchain: stable docker-options: -e CI @@ -458,7 +458,7 @@ jobs: - uses: actions/upload-artifact@v4 with: - name: pypi_files + name: pypi_files_${{ matrix.os }}_${{ matrix.target }}_${{ matrix.interpreter || 'all' }}_${{ matrix.manylinux }} path: dist build-pgo: @@ -553,7 +553,7 @@ jobs: - uses: actions/upload-artifact@v4 with: - name: pypi_files_pgo + name: pypi_files_${{ matrix.os }}_${{ matrix.interpreter }} path: dist inspect-pypi-assets: @@ -566,22 +566,11 @@ jobs: - name: get dist artifacts uses: actions/download-artifact@v4 with: - name: pypi_files + pattern: pypi_files_* + merge-multiple: true path: dist - - name: list dist files before PGO builds - run: | - ls -lh dist/ - ls -l dist/ - echo "`ls dist | wc -l` files" - - - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) - uses: actions/download-artifact@v4 - with: - name: pypi_files_pgo - path: dist - - - name: list dist files with PGO builds + - name: list dist files run: | ls -lh dist/ ls -l dist/ @@ -618,13 +607,8 @@ jobs: - name: get dist artifacts uses: actions/download-artifact@v4 with: - name: pypi_files - path: dist - - - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) - uses: actions/download-artifact@v4 - with: - name: pypi_files_pgo + pattern: pypi_files_linux_* + merge-multiple: true path: dist - uses: uraimo/run-on-arch-action@v2.6.0 @@ -676,13 +660,8 @@ jobs: - name: get dist artifacts uses: actions/download-artifact@v4 with: - name: pypi_files - path: dist - - - name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) - uses: actions/download-artifact@v4 - with: - name: pypi_files_pgo + pattern: pypi_files_* + merge-multiple: true path: dist - run: pip install typing-extensions @@ -711,13 +690,8 @@ jobs: - name: get dist artifacts uses: actions/download-artifact@v4 with: - name: pypi_files - path: dist - - - 
name: get PGO dist artifacts (comes after "get dist artifacts" to so these files override the non-PGO builds) - uses: actions/download-artifact@v4 - with: - name: pypi_files_pgo + pattern: pypi_files_* + merge-multiple: true path: dist - run: twine check --strict dist/* From 700e17d8ce69caed42776365874006a9dfb90cf3 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 12 Jan 2024 22:23:29 +0100 Subject: [PATCH 175/550] Group dependencies on dependabot updates (#1149) --- .github/dependabot.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index b93ab648d..67fb85756 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,6 +9,10 @@ updates: directory: "/" schedule: interval: "monthly" + groups: + python-packages: + patterns: + - "*" - package-ecosystem: "github-actions" directory: "/" From 8030c5ae1cfb0517d333d5647ee66af8ab820440 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 14 Jan 2024 21:38:48 +0000 Subject: [PATCH 176/550] add to_json ser benchmark --- tests/benchmarks/test_serialization_micro.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/benchmarks/test_serialization_micro.py b/tests/benchmarks/test_serialization_micro.py index 6a78de87a..9200fc24b 100644 --- a/tests/benchmarks/test_serialization_micro.py +++ b/tests/benchmarks/test_serialization_micro.py @@ -4,7 +4,7 @@ import pytest -from pydantic_core import SchemaSerializer, SchemaValidator, core_schema +from pydantic_core import SchemaSerializer, SchemaValidator, core_schema, to_json class TestBenchmarkSimpleModel: @@ -394,3 +394,18 @@ def test_filter(benchmark): @benchmark def t(): v.to_python(['a', 'b', 'c', 'd', 'e'], include={-1, -2}) + + +@pytest.mark.benchmark(group='list-of-lists') +def test_to_json_list_of_lists(benchmark): + data = [[i + j for j in range(10)] for i in range(1000)] + + benchmark(to_json, data) + + +@pytest.mark.benchmark(group='list-of-lists') +def 
test_ser_list_of_lists(benchmark): + s = SchemaSerializer(core_schema.list_schema(core_schema.list_schema(core_schema.int_schema()))) + data = [[i + j for j in range(10)] for i in range(1000)] + + benchmark(s.to_json, data) From fa65fcdbbe70cd2eff9f94e156735e4a67545c6e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 09:52:03 +0000 Subject: [PATCH 177/550] simplify instantiation of undefined type (#1157) --- src/argument_markers.rs | 6 ------ src/validators/model.rs | 6 ++++-- src/validators/with_default.rs | 4 +++- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/argument_markers.rs b/src/argument_markers.rs index a4472351d..3e73fb629 100644 --- a/src/argument_markers.rs +++ b/src/argument_markers.rs @@ -102,9 +102,3 @@ impl PydanticUndefinedType { "PydanticUndefined" } } - -impl PydanticUndefinedType { - pub fn py_undefined() -> Py { - Python::with_gil(PydanticUndefinedType::new) - } -} diff --git a/src/validators/model.rs b/src/validators/model.rs index 1831d2fde..0cc39b6a3 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -60,6 +60,7 @@ pub struct ModelValidator { frozen: bool, custom_init: bool, root_model: bool, + undefined: PyObject, name: String, } @@ -93,6 +94,7 @@ impl BuildValidator for ModelValidator { frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), custom_init: schema.get_as(intern!(py, "custom_init"))?.unwrap_or(false), root_model: schema.get_as(intern!(py, "root_model"))?.unwrap_or(false), + undefined: PydanticUndefinedType::new(py).to_object(py), // Get the class's `__name__`, not using `class.name()` since it uses `__qualname__` // which is not what we want here name: class.getattr(intern!(py, "__name__"))?.extract()?, @@ -229,7 +231,7 @@ impl ModelValidator { let output = self.validator.validate(py, input, state)?; if self.root_model { - let fields_set = if input.to_object(py).is(&PydanticUndefinedType::py_undefined()) { + let fields_set = if 
input.to_object(py).is(&self.undefined) { PySet::empty(py)? } else { PySet::new(py, [&String::from(ROOT_FIELD)])? @@ -270,7 +272,7 @@ impl ModelValidator { let instance_ref = instance.as_ref(py); if self.root_model { - let fields_set = if input.to_object(py).is(&PydanticUndefinedType::py_undefined()) { + let fields_set = if input.to_object(py).is(&self.undefined) { PySet::empty(py)? } else { PySet::new(py, [&String::from(ROOT_FIELD)])? diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index 443c0e271..a17c2cefc 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -74,6 +74,7 @@ pub struct WithDefaultValidator { validate_default: bool, copy_default: bool, name: String, + undefined: PyObject, } impl BuildValidator for WithDefaultValidator { @@ -118,6 +119,7 @@ impl BuildValidator for WithDefaultValidator { validate_default: schema_or_config_same(schema, config, intern!(py, "validate_default"))?.unwrap_or(false), copy_default, name, + undefined: PydanticUndefinedType::new(py).to_object(py), } .into()) } @@ -132,7 +134,7 @@ impl Validator for WithDefaultValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult { - if input.to_object(py).is(&PydanticUndefinedType::py_undefined()) { + if input.to_object(py).is(&self.undefined) { Ok(self.default_value(py, None::, state)?.unwrap()) } else { match self.validator.validate(py, input, state) { From b9b28ad69fa2ffb02f5e294c5f4297291ec0bff2 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 15 Jan 2024 10:24:29 +0000 Subject: [PATCH 178/550] unify 'profile.profiling' configuration (#1158) --- .github/workflows/codspeed.yml | 6 ++---- Cargo.toml | 4 +++- Makefile | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index 5d69d4d5f..fc6591314 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -52,8 +52,7 @@ jobs: key: v1 - 
name: Compile pydantic-core for profiling - run: | - pip install -e . --config-settings=build-args='--verbose --profile codspeed' -v + run: make build-profiling env: CONST_RANDOM_SEED: 0 # Fix the compile time RNG seed RUSTFLAGS: "-Cprofile-generate=${{ github.workspace }}/profdata" @@ -65,8 +64,7 @@ jobs: run: rustup run stable bash -c '$RUSTUP_HOME/toolchains/$RUSTUP_TOOLCHAIN/lib/rustlib/x86_64-unknown-linux-gnu/bin/llvm-profdata merge -o ${{ github.workspace }}/merged.profdata ${{ github.workspace }}/profdata' - name: Compile pydantic-core for benchmarking - run: | - pip install -e . --config-settings=build-args='--verbose --profile codspeed' -v + run: make build-profiling env: CONST_RANDOM_SEED: 0 # Fix the compile time RNG seed RUSTFLAGS: "-Cprofile-use=${{ github.workspace }}/merged.profdata" diff --git a/Cargo.toml b/Cargo.toml index 66a8667e0..9bad92573 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,9 @@ strip = true debug = true strip = false -[profile.codspeed] +# This is separate to benchmarks because `bench` ends up building testing +# harnesses into code, as it's a special cargo profile. +[profile.profiling] inherits = "release" debug = true strip = false diff --git a/Makefile b/Makefile index e1f0bc607..ba8adca62 100644 --- a/Makefile +++ b/Makefile @@ -48,9 +48,9 @@ endif build-profiling: @rm -f python/pydantic_core/*.so ifneq ($(USE_MATURIN),) - CARGO_PROFILE_RELEASE_STRIP=false CARGO_PROFILE_RELEASE_DEBUG=true maturin develop --release + maturin develop '--profile profiling' else - CARGO_PROFILE_RELEASE_STRIP=false CARGO_PROFILE_RELEASE_DEBUG=true pip install -v -e . + pip install -v -e . 
--config-settings=build-args='--profile profiling' endif .PHONY: build-coverage From 3f7b0b428d692cfedf67c10ef2fa9571a98f6050 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:28:42 +0000 Subject: [PATCH 179/550] Bump the python-packages group with 4 updates (#1154) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 6 +++--- tests/requirements.txt | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index e1e5e4268..043482646 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ -griffe==0.38.0 -pyright==1.1.339 -ruff==0.1.7 +griffe==0.38.1 +pyright==1.1.345 +ruff==0.1.13 mypy==1.8.0 diff --git a/tests/requirements.txt b/tests/requirements.txt index 4297177d8..29bf5bcad 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,6 +1,6 @@ coverage==7.4.0 dirty-equals==0.7.1.post0 -hypothesis==6.92.5 +hypothesis==6.92.9 # TODO: remove manual override for dateutil once a version newer than 2.8.2 is # released which removes use of deprecated utcfromtimestamp git+https://github.com/dateutil/dateutil.git@f2293200747fb03d56c6c5997bfebeabe703576f From a1f6023d0acca9748c4f6120a336cde916d199a6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:28:55 +0000 Subject: [PATCH 180/550] Bump base64 from 0.21.5 to 0.21.7 (#1153) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2f17cec97..ab1212242 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,9 +32,9 @@ checksum = 
"d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.5" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "bitflags" diff --git a/Cargo.toml b/Cargo.toml index 9bad92573..905122b1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ ahash = "0.8.7" url = "2.5.0" # idna is already required by url, added here to be explicit idna = "0.5.0" -base64 = "0.21.5" +base64 = "0.21.7" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.6.1" From 1ff756f03a6e91c6996c98baa3a55111ec536401 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:29:08 +0000 Subject: [PATCH 181/550] Bump pyo3 from 0.20.1 to 0.20.2 (#1152) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 20 ++++++++++---------- Cargo.toml | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ab1212242..c80c15e4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -346,9 +346,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82ad98ce1991c9c70c3464ba4187337b9c45fcbbb060d46dca15f0c075e14e2" +checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" dependencies = [ "cfg-if", "indoc", @@ -364,9 +364,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5503d0b3aee2c7a8dbb389cd87cd9649f675d4c7f60ca33699a3e3859d81a891" +checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" 
dependencies = [ "once_cell", "python3-dll-a", @@ -375,9 +375,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18a79e8d80486a00d11c0dcb27cd2aa17c022cc95c677b461f01797226ba8f41" +checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" dependencies = [ "libc", "pyo3-build-config", @@ -385,9 +385,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4b0dc7eaa578604fab11c8c7ff8934c71249c61d4def8e272c76ed879f03d4" +checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -397,9 +397,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.1" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816a4f709e29ddab2e3cdfe94600d554c5556cad0ddfeea95c47b580c3247fa4" +checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 905122b1d..0caafc668 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ include = [ ] [dependencies] -pyo3 = { version = "0.20.1", features = ["generate-import-lib", "num-bigint"] } +pyo3 = { version = "0.20.2", features = ["generate-import-lib", "num-bigint"] } regex = "1.10.2" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" @@ -70,7 +70,7 @@ debug = true strip = false [dev-dependencies] -pyo3 = { version = "0.20.1", features = ["auto-initialize"] } +pyo3 = { version = "0.20.2", features = ["auto-initialize"] } [build-dependencies] version_check = "0.9.4" From 545f8c3f8c476673e0068c48006509a36539fcb2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 
Jan 2024 10:29:23 +0000 Subject: [PATCH 182/550] Bump pyo3-build-config from 0.20.1 to 0.20.2 (#1151) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0caafc668..8f31c28b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,7 +75,7 @@ pyo3 = { version = "0.20.2", features = ["auto-initialize"] } [build-dependencies] version_check = "0.9.4" # used where logic has to be version/distribution specific, e.g. pypy -pyo3-build-config = { version = "0.20.1" } +pyo3-build-config = { version = "0.20.2" } [lints.clippy] dbg_macro = "warn" From 579166767174e267d7b68f3ba01d5d6de0158b10 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:29:32 +0000 Subject: [PATCH 183/550] Bump serde from 1.0.193 to 1.0.195 (#1150) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 20 ++++++++++---------- Cargo.toml | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c80c15e4e..34b46ee65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -312,9 +312,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" dependencies = [ "unicode-ident", ] @@ -418,9 +418,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.29" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" +checksum = 
"291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -483,18 +483,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.193" +version = "1.0.195" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.193" +version = "1.0.195" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" dependencies = [ "proc-macro2", "quote", @@ -559,9 +559,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.38" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 8f31c28b3..283e364f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" serde_json = {version = "1.0.109", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" -serde = { version = "1.0.193", features = ["derive"] } +serde = { version = "1.0.195", features = ["derive"] } speedate = "0.13.0" smallvec = "1.11.2" ahash = "0.8.7" From 5d3aa43f4b5421648d657cd316b6c237cc435ea1 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 15 Jan 2024 10:55:33 +0000 Subject: [PATCH 184/550] correct build-profiling make command (#1160) --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/Makefile b/Makefile index ba8adca62..8fbe91389 100644 --- a/Makefile +++ b/Makefile @@ -48,7 +48,7 @@ endif build-profiling: @rm -f python/pydantic_core/*.so ifneq ($(USE_MATURIN),) - maturin develop '--profile profiling' + maturin develop --profile profiling else pip install -v -e . --config-settings=build-args='--profile profiling' endif From d7cf72d8d591632889967b37eb4367f9f9f889e7 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 12:23:05 +0000 Subject: [PATCH 185/550] Int extraction (#1155) --- .gitignore | 3 +++ src/errors/types.rs | 2 +- src/errors/value_exception.rs | 2 +- src/input/input_python.rs | 10 ++++---- src/input/return_enums.rs | 6 ++--- src/lookup_key.rs | 2 +- src/serializers/infer.rs | 5 +++- src/serializers/type_serializers/literal.rs | 4 +-- src/tools.rs | 28 +++++++++++++++------ tests/benchmarks/test_micro_benchmarks.py | 12 +++++++++ tests/validators/test_int.py | 15 ++++++++--- 11 files changed, 64 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 6c9ace5f4..efffcbf69 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ node_modules/ /foobar.py /python/pydantic_core/*.so /src/self_schema.py + +# samply +/profile.json diff --git a/src/errors/types.rs b/src/errors/types.rs index cfa96221e..eddd7dbaa 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -786,7 +786,7 @@ impl From for Number { impl FromPyObject<'_> for Number { fn extract(obj: &PyAny) -> PyResult { - if let Ok(int) = extract_i64(obj) { + if let Some(int) = extract_i64(obj) { Ok(Number::Int(int)) } else if let Ok(float) = obj.extract::() { Ok(Number::Float(float)) diff --git a/src/errors/value_exception.rs b/src/errors/value_exception.rs index a88610eef..68f93d463 100644 --- a/src/errors/value_exception.rs +++ b/src/errors/value_exception.rs @@ -122,7 +122,7 @@ impl PydanticCustomError { let key: &PyString = key.downcast()?; if let Ok(py_str) = value.downcast::() { message = message.replace(&format!("{{{}}}", 
key.to_str()?), py_str.to_str()?); - } else if let Ok(value_int) = extract_i64(value) { + } else if let Some(value_int) = extract_i64(value) { message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string()); } else { // fallback for anything else just in case diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 5d4d4826c..2eaaebe10 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -96,7 +96,7 @@ impl AsLocItem for PyAny { fn as_loc_item(&self) -> LocItem { if let Ok(py_str) = self.downcast::() { py_str.to_string_lossy().as_ref().into() - } else if let Ok(key_int) = extract_i64(self) { + } else if let Some(key_int) = extract_i64(self) { key_int.into() } else { safe_repr(self).to_string().into() @@ -292,7 +292,7 @@ impl<'a> Input<'a> for PyAny { if !strict { if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? { return str_as_bool(self, &cow_str).map(ValidationMatch::lax); - } else if let Ok(int) = extract_i64(self) { + } else if let Some(int) = extract_i64(self) { return int_as_bool(self, int).map(ValidationMatch::lax); } else if let Ok(float) = self.extract::() { if let Ok(int) = float_as_int(self, float) { @@ -635,7 +635,7 @@ impl<'a> Input<'a> for PyAny { bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior) } else if PyBool::is_exact_type_of(self) { Err(ValError::new(ErrorTypeDefaults::TimeType, self)) - } else if let Ok(int) = extract_i64(self) { + } else if let Some(int) = extract_i64(self) { int_as_time(self, int, 0) } else if let Ok(float) = self.extract::() { float_as_time(self, float) @@ -669,7 +669,7 @@ impl<'a> Input<'a> for PyAny { bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior) } else if PyBool::is_exact_type_of(self) { Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)) - } else if let Ok(int) = extract_i64(self) { + } else if let Some(int) = extract_i64(self) { int_as_datetime(self, int, 0) } else if let 
Ok(float) = self.extract::() { float_as_datetime(self, float) @@ -706,7 +706,7 @@ impl<'a> Input<'a> for PyAny { bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior) } else if let Ok(py_bytes) = self.downcast::() { bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior) - } else if let Ok(int) = extract_i64(self) { + } else if let Some(int) = extract_i64(self) { Ok(int_as_duration(self, int)?.into()) } else if let Ok(float) = self.extract::() { Ok(float_as_duration(self, float)?.into()) diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index fa70880ca..905b895f9 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -23,7 +23,7 @@ use pyo3::PyTypeInfo; use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult}; -use crate::tools::py_err; +use crate::tools::{extract_i64, py_err}; use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator}; use super::input_string::StringMapping; @@ -863,7 +863,7 @@ pub enum EitherInt<'a> { impl<'a> EitherInt<'a> { pub fn upcast(py_any: &'a PyAny) -> ValResult { // Safety: we know that py_any is a python int - if let Ok(int_64) = py_any.extract::() { + if let Some(int_64) = extract_i64(py_any) { Ok(Self::I64(int_64)) } else { let big_int: BigInt = py_any.extract()?; @@ -1021,7 +1021,7 @@ impl<'a> Rem for &'a Int { impl<'a> FromPyObject<'a> for Int { fn extract(obj: &'a PyAny) -> PyResult { - if let Ok(i) = obj.extract::() { + if let Some(i) = extract_i64(obj) { Ok(Int::I64(i)) } else if let Ok(b) = obj.extract::() { Ok(Int::Big(b)) diff --git a/src/lookup_key.rs b/src/lookup_key.rs index e145c1f41..e4d2dcce7 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -429,7 +429,7 @@ impl PathItem { } else { Ok(Self::Pos(usize_key)) } - } else if let Ok(int_key) = extract_i64(obj) { + } else if let Some(int_key) = extract_i64(obj) { if 
index == 0 { py_err!(PyTypeError; "The first item in an alias path should be a string") } else { diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 13c20062b..e39ca38f8 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -123,7 +123,10 @@ pub(crate) fn infer_to_python_known( // `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py), // have to do this to make sure subclasses of for example str are upcast to `str` - ObType::IntSubclass => extract_i64(value)?.into_py(py), + ObType::IntSubclass => match extract_i64(value) { + Some(v) => v.into_py(py), + None => return py_err!(PyTypeError; "expected int, got {}", safe_repr(value)), + }, ObType::Float | ObType::FloatSubclass => { let v = value.extract::()?; if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null { diff --git a/src/serializers/type_serializers/literal.rs b/src/serializers/type_serializers/literal.rs index 846b8843f..d6b08afa5 100644 --- a/src/serializers/type_serializers/literal.rs +++ b/src/serializers/type_serializers/literal.rs @@ -46,7 +46,7 @@ impl BuildSerializer for LiteralSerializer { repr_args.push(item.repr()?.extract()?); if let Ok(bool) = item.downcast::() { expected_py.append(bool)?; - } else if let Ok(int) = extract_i64(item) { + } else if let Some(int) = extract_i64(item) { expected_int.insert(int); } else if let Ok(py_str) = item.downcast::() { expected_str.insert(py_str.to_str()?.to_string()); @@ -79,7 +79,7 @@ impl LiteralSerializer { fn check<'a>(&self, value: &'a PyAny, extra: &Extra) -> PyResult> { if extra.check.enabled() { if !self.expected_int.is_empty() && !PyBool::is_type_of(value) { - if let Ok(int) = extract_i64(value) { + if let Some(int) = extract_i64(value) { if self.expected_int.contains(&int) { return Ok(OutputValue::OkInt(int)); } diff --git a/src/tools.rs b/src/tools.rs index 
af58131f5..bdc41583c 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -1,9 +1,9 @@ use std::borrow::Cow; -use pyo3::exceptions::{PyKeyError, PyTypeError}; +use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyInt, PyString}; -use pyo3::{intern, FromPyObject, PyTypeInfo}; +use pyo3::types::{PyDict, PyString}; +use pyo3::{ffi, intern, FromPyObject}; pub trait SchemaDict<'py> { fn get_as(&'py self, key: &PyString) -> PyResult> @@ -99,10 +99,24 @@ pub fn safe_repr(v: &PyAny) -> Cow { } } -pub fn extract_i64(v: &PyAny) -> PyResult { - if PyInt::is_type_of(v) { - v.extract() +/// Extract an i64 from a python object more quickly, see +/// https://github.com/PyO3/pyo3/pull/3742#discussion_r1451763928 +#[cfg(not(any(target_pointer_width = "32", windows, PyPy)))] +pub fn extract_i64(obj: &PyAny) -> Option { + let val = unsafe { ffi::PyLong_AsLong(obj.as_ptr()) }; + if val == -1 && PyErr::occurred(obj.py()) { + unsafe { ffi::PyErr_Clear() }; + None } else { - py_err!(PyTypeError; "expected int, got {}", safe_repr(v)) + Some(val) + } +} + +#[cfg(any(target_pointer_width = "32", windows, PyPy))] +pub fn extract_i64(v: &PyAny) -> Option { + if v.is_instance_of::() { + v.extract().ok() + } else { + None } } diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index f1ec32eef..c2320427c 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -1232,6 +1232,18 @@ def test_strict_int(benchmark): benchmark(v.validate_python, 42) +@pytest.mark.benchmark(group='strict_int') +def test_strict_int_fails(benchmark): + v = SchemaValidator(core_schema.int_schema(strict=True)) + + @benchmark + def t(): + try: + v.validate_python(()) + except ValidationError: + pass + + @pytest.mark.benchmark(group='int_range') def test_int_range(benchmark): v = SchemaValidator(core_schema.int_schema(gt=0, lt=100)) diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py 
index 80dd1cf73..35a13f6a7 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -29,6 +29,8 @@ ('123456789123456.00001', Err('Input should be a valid integer, unable to parse string as an integer')), (int(1e10), int(1e10)), (i64_max, i64_max), + (i64_max + 1, i64_max + 1), + (i64_max * 2, i64_max * 2), pytest.param( 12.5, Err('Input should be a valid integer, got a number with a fractional part [type=int_from_float'), @@ -106,10 +108,15 @@ def test_int(input_value, expected): @pytest.mark.parametrize( 'input_value,expected', [ - (Decimal('1'), 1), - (Decimal('1.0'), 1), - (i64_max, i64_max), - (i64_max + 1, i64_max + 1), + pytest.param(Decimal('1'), 1), + pytest.param(Decimal('1.0'), 1), + pytest.param(i64_max, i64_max, id='i64_max'), + pytest.param(i64_max + 1, i64_max + 1, id='i64_max+1'), + pytest.param( + -1, + Err('Input should be greater than 0 [type=greater_than, input_value=-1, input_type=int]'), + id='-1', + ), ( -i64_max + 1, Err('Input should be greater than 0 [type=greater_than, input_value=-9223372036854775806, input_type=int]'), From e1cb0ebe5633fee835cd149768ebf43adfde4f6b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 16:26:37 +0000 Subject: [PATCH 186/550] improve performance of recursion guard (#1156) Co-authored-by: David Hewitt Co-authored-by: David Hewitt --- src/recursion_guard.rs | 144 ++++++++++++++++++++++++++++------ src/serializers/extra.rs | 14 ++-- src/validators/definitions.rs | 16 ++-- tests/serializers/test_any.py | 2 +- 4 files changed, 134 insertions(+), 42 deletions(-) diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index 453f01a1d..fe5b1bcdd 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -1,4 +1,5 @@ use ahash::AHashSet; +use std::mem::MaybeUninit; type RecursionKey = ( // Identifier for the input object, e.g. 
the id() of a Python dict @@ -13,56 +14,147 @@ type RecursionKey = ( /// It's used in `validators/definition` to detect when a reference is reused within itself. #[derive(Debug, Clone, Default)] pub struct RecursionGuard { - ids: Option>, + ids: RecursionStack, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators - depth: u16, + depth: u8, } // A hard limit to avoid stack overflows when rampant recursion occurs -pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) { +pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) { // wasm and windows PyPy have very limited stack sizes - 50 + 49 } else if cfg!(any(PyPy, windows)) { // PyPy and Windows in general have more restricted stack space - 100 + 99 } else { 255 }; impl RecursionGuard { - // insert a new id into the set, return whether the set already had the id in it - pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool { - match self.ids { - // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert - // "If the set did not have this value present, `true` is returned." 
- Some(ref mut set) => !set.insert((obj_id, node_id)), - None => { - let mut set: AHashSet = AHashSet::with_capacity(10); - set.insert((obj_id, node_id)); - self.ids = Some(set); - false - } - } + // insert a new value + // * return `false` if the stack already had it in it + // * return `true` if the stack didn't have it in it and it was inserted + pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool { + self.ids.insert((obj_id, node_id)) } // see #143 this is used as a backup in case the identity check recursion guard fails #[must_use] + #[cfg(any(target_family = "wasm", windows, PyPy))] pub fn incr_depth(&mut self) -> bool { - self.depth += 1; - self.depth >= RECURSION_GUARD_LIMIT + // use saturating_add as it's faster (since there's no error path) + // and the RECURSION_GUARD_LIMIT check will be hit before it overflows + debug_assert!(RECURSION_GUARD_LIMIT < 255); + self.depth = self.depth.saturating_add(1); + self.depth > RECURSION_GUARD_LIMIT + } + + #[must_use] + #[cfg(not(any(target_family = "wasm", windows, PyPy)))] + pub fn incr_depth(&mut self) -> bool { + debug_assert_eq!(RECURSION_GUARD_LIMIT, 255); + // use checked_add to check if we've hit the limit + if let Some(depth) = self.depth.checked_add(1) { + self.depth = depth; + false + } else { + true + } } pub fn decr_depth(&mut self) { - self.depth -= 1; + // for the same reason as incr_depth, use saturating_sub + self.depth = self.depth.saturating_sub(1); } pub fn remove(&mut self, obj_id: usize, node_id: usize) { - match self.ids { - Some(ref mut set) => { - set.remove(&(obj_id, node_id)); + self.ids.remove(&(obj_id, node_id)); + } +} + +// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower +const ARRAY_SIZE: usize = 16; + +#[derive(Debug, Clone)] +enum RecursionStack { + Array { + data: [MaybeUninit; ARRAY_SIZE], + len: usize, + }, + Set(AHashSet), +} + +impl Default for RecursionStack { + fn default() -> Self { + Self::Array { + 
data: std::array::from_fn(|_| MaybeUninit::uninit()), + len: 0, + } + } +} + +impl RecursionStack { + // insert a new value + // * return `false` if the stack already had it in it + // * return `true` if the stack didn't have it in it and it was inserted + pub fn insert(&mut self, v: RecursionKey) -> bool { + match self { + Self::Array { data, len } => { + if *len < ARRAY_SIZE { + for value in data.iter().take(*len) { + // Safety: reading values within bounds + if unsafe { value.assume_init() } == v { + return false; + } + } + + data[*len].write(v); + *len += 1; + true + } else { + let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1); + for existing in data.iter() { + // Safety: the array is fully initialized + set.insert(unsafe { existing.assume_init() }); + } + let inserted = set.insert(v); + *self = Self::Set(set); + inserted + } + } + // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert + // "If the set did not have this value present, `true` is returned." + Self::Set(set) => set.insert(v), + } + } + + pub fn remove(&mut self, v: &RecursionKey) { + match self { + Self::Array { data, len } => { + *len = len.checked_sub(1).expect("remove from empty recursion guard"); + // Safety: this is reading what was the back of the initialized array + let removed = unsafe { data.get_unchecked_mut(*len) }; + assert!(unsafe { removed.assume_init_ref() } == v, "remove did not match insert"); + // this should compile away to a noop + unsafe { std::ptr::drop_in_place(removed.as_mut_ptr()) } + } + Self::Set(set) => { + set.remove(v); + } + } + } +} + +impl Drop for RecursionStack { + fn drop(&mut self) { + // This should compile away to a noop as Recursion>Key doesn't implement Drop, but it seemed + // desirable to leave this in for safety in case that should change in the future + if let Self::Array { data, len } = self { + for value in data.iter_mut().take(*len) { + // Safety: reading values within bounds + unsafe { 
std::ptr::drop_in_place(value.as_mut_ptr()) }; } - None => unreachable!(), - }; + } } } diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 37307055e..b3978613a 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -346,17 +346,17 @@ pub struct SerRecursionGuard { impl SerRecursionGuard { pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult { - // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert - // "If the set did not have this value present, `true` is returned." let id = value.as_ptr() as usize; let mut guard = self.guard.borrow_mut(); - if guard.contains_or_insert(id, def_ref_id) { - Err(PyValueError::new_err("Circular reference detected (id repeated)")) - } else if guard.incr_depth() { - Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) + if guard.insert(id, def_ref_id) { + if guard.incr_depth() { + Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) + } else { + Ok(id) + } } else { - Ok(id) + Err(PyValueError::new_err("Circular reference detected (id repeated)")) } } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 0b5f78c10..e8c67a690 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -76,10 +76,7 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) - } else { + if state.recursion_guard.insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } @@ -87,6 +84,9 @@ impl Validator for DefinitionRefValidator { state.recursion_guard.remove(id, 
self.definition.id()); state.recursion_guard.decr_depth(); output + } else { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) } } else { validator.validate(py, input, state) @@ -105,10 +105,7 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) - } else { + if state.recursion_guard.insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } @@ -116,6 +113,9 @@ impl Validator for DefinitionRefValidator { state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output + } else { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) } } else { validator.validate_assignment(py, obj, field_name, field_value, state) diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 98ec22c1f..fa6e702fe 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -371,7 +371,7 @@ def fallback_func(obj): f = FoobarCount(0) v = 0 # when recursion is detected and we're in mode python, we just return the value - expected_visits = pydantic_core._pydantic_core._recursion_limit - 1 + expected_visits = pydantic_core._pydantic_core._recursion_limit assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'') with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'): From 5a1385b86e28f6c3e30cf4a22a2b173a2ecb4742 Mon Sep 17 
00:00:00 2001 From: Samuel Colvin Date: Tue, 16 Jan 2024 19:13:46 +0000 Subject: [PATCH 187/550] dataclass serialization speedups (#1162) --- src/serializers/fields.rs | 257 +++++++++++------- src/serializers/infer.rs | 130 +++++---- src/serializers/shared.rs | 42 +-- src/serializers/type_serializers/dataclass.rs | 74 +++-- tests/benchmarks/test_serialization_micro.py | 46 ++++ 5 files changed, 362 insertions(+), 187 deletions(-) diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index f48f8e0b9..cefdec1d7 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -100,6 +100,15 @@ pub struct GeneralFieldsSerializer { required_fields: usize, } +macro_rules! option_length { + ($op_has_len:expr) => { + match $op_has_len { + Some(ref has_len) => has_len.len(), + None => 0, + } + }; +} + impl GeneralFieldsSerializer { pub(super) fn new( fields: AHashMap, @@ -136,50 +145,21 @@ impl GeneralFieldsSerializer { } } } -} - -macro_rules! option_length { - ($op_has_len:expr) => { - match $op_has_len { - Some(ref has_len) => has_len.len(), - None => 0, - } - }; -} -impl_py_gc_traverse!(GeneralFieldsSerializer { - fields, - computed_fields -}); - -impl TypeSerializer for GeneralFieldsSerializer { - fn to_python( + pub fn main_to_python<'py>( &self, - value: &PyAny, - include: Option<&PyAny>, - exclude: Option<&PyAny>, - extra: &Extra, - ) -> PyResult { - let py = value.py(); - // If there is already a model registered (from a dataclass, BaseModel) - // then do not touch it - // If there is no model, we (a TypedDict) are the model - let td_extra = Extra { - model: extra.model.map_or_else(|| Some(value), Some), - ..*extra - }; - let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) { - main_extra_dict - } else { - td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?; - return infer_to_python(value, include, exclude, &td_extra); - }; - + py: Python<'py>, + main_iter: impl Iterator>, + include: 
Option<&'py PyAny>, + exclude: Option<&'py PyAny>, + extra: Extra, + ) -> PyResult<&'py PyDict> { let output_dict = PyDict::new(py); let mut used_req_fields: usize = 0; // NOTE! we maintain the order of the input dict assuming that's right - for (key, value) in main_dict { + for result in main_iter { + let (key, value) = result?; let key_str = key_str(key)?; let op_field = self.fields.get(key_str); if extra.exclude_none && value.is_none() { @@ -190,16 +170,16 @@ impl TypeSerializer for GeneralFieldsSerializer { } continue; } - let extra = Extra { + let field_extra = Extra { field_name: Some(key_str), - ..td_extra + ..extra }; if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? { if let Some(field) = op_field { if let Some(ref serializer) = field.serializer { - if !exclude_default(value, &extra, serializer)? { - let value = serializer.to_python(value, next_include, next_exclude, &extra)?; - let output_key = field.get_key_py(output_dict.py(), &extra); + if !exclude_default(value, &field_extra, serializer)? 
{ + let value = serializer.to_python(value, next_include, next_exclude, &field_extra)?; + let output_key = field.get_key_py(output_dict.py(), &field_extra); output_dict.set_item(output_key, value)?; } } @@ -209,23 +189,140 @@ impl TypeSerializer for GeneralFieldsSerializer { } } else if self.mode == FieldsMode::TypedDictAllow { let value = match &self.extra_serializer { - Some(serializer) => serializer.to_python(value, next_include, next_exclude, &extra)?, - None => infer_to_python(value, next_include, next_exclude, &extra)?, + Some(serializer) => serializer.to_python(value, next_include, next_exclude, &field_extra)?, + None => infer_to_python(value, next_include, next_exclude, &field_extra)?, }; output_dict.set_item(key, value)?; - } else if extra.check == SerCheck::Strict { + } else if field_extra.check == SerCheck::Strict { return Err(PydanticSerializationUnexpectedValue::new_err(None)); } } } - if td_extra.check.enabled() + + if extra.check.enabled() // If any of these are true we can't count fields && !(extra.exclude_defaults || extra.exclude_unset || extra.exclude_none) // Check for missing fields, we can't have extra fields here && self.required_fields > used_req_fields { - return Err(PydanticSerializationUnexpectedValue::new_err(None)); + Err(PydanticSerializationUnexpectedValue::new_err(None)) + } else { + Ok(output_dict) + } + } + + pub fn main_serde_serialize<'py, S: serde::ser::Serializer>( + &self, + main_iter: impl Iterator>, + expected_len: usize, + serializer: S, + include: Option<&'py PyAny>, + exclude: Option<&'py PyAny>, + extra: Extra, + ) -> Result { + // NOTE! 
As above, we maintain the order of the input dict assuming that's right + // we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used + let mut map = serializer.serialize_map(Some(expected_len))?; + + for result in main_iter { + let (key, value) = result.map_err(py_err_se_err)?; + if extra.exclude_none && value.is_none() { + continue; + } + let key_str = key_str(key).map_err(py_err_se_err)?; + let field_extra = Extra { + field_name: Some(key_str), + ..extra + }; + + let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; + if let Some((next_include, next_exclude)) = filter { + if let Some(field) = self.fields.get(key_str) { + if let Some(ref serializer) = field.serializer { + if !exclude_default(value, &field_extra, serializer).map_err(py_err_se_err)? { + let s = + PydanticSerializer::new(value, serializer, next_include, next_exclude, &field_extra); + let output_key = field.get_key_json(key_str, &field_extra); + map.serialize_entry(&output_key, &s)?; + } + } + } else if self.mode == FieldsMode::TypedDictAllow { + let output_key = infer_json_key(key, &field_extra).map_err(py_err_se_err)?; + let s = SerializeInfer::new(value, next_include, next_exclude, &field_extra); + map.serialize_entry(&output_key, &s)?; + } + // no error case here since unions (which need the error case) use `to_python(..., mode='json')` + } } + Ok(map) + } + + pub fn add_computed_fields_python( + &self, + model: Option<&PyAny>, + output_dict: &PyDict, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult<()> { + if let Some(ref computed_fields) = self.computed_fields { + if let Some(model_value) = model { + let cf_extra = Extra { model, ..*extra }; + computed_fields.to_python(model_value, output_dict, &self.filter, include, exclude, &cf_extra)?; + } + } + Ok(()) + } + + pub fn add_computed_fields_json( + &self, + model: Option<&PyAny>, + map: &mut S::SerializeMap, + include: Option<&PyAny>, + 
exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result<(), S::Error> { + if let Some(ref computed_fields) = self.computed_fields { + if let Some(model) = model { + computed_fields.serde_serialize::(model, map, &self.filter, include, exclude, extra)?; + } + } + Ok(()) + } + + pub fn computed_field_count(&self) -> usize { + option_length!(self.computed_fields) + } +} + +impl_py_gc_traverse!(GeneralFieldsSerializer { + fields, + computed_fields +}); + +impl TypeSerializer for GeneralFieldsSerializer { + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult { + let py = value.py(); + // If there is already a model registered (from a dataclass, BaseModel) + // then do not touch it + // If there is no model, we (a TypedDict) are the model + let model = extra.model.map_or_else(|| Some(value), Some); + let td_extra = Extra { model, ..*extra }; + let (main_dict, extra_dict) = if let Some(main_extra_dict) = self.extract_dicts(value) { + main_extra_dict + } else { + td_extra.warnings.on_fallback_py(self.get_name(), value, &td_extra)?; + return infer_to_python(value, include, exclude, &td_extra); + }; + + let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?; + // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { for (key, value) in extra_dict { @@ -241,11 +338,7 @@ impl TypeSerializer for GeneralFieldsSerializer { } } } - if let Some(ref computed_fields) = self.computed_fields { - if let Some(model) = td_extra.model { - computed_fields.to_python(model, output_dict, &self.filter, include, exclude, &td_extra)?; - } - } + self.add_computed_fields_python(model, output_dict, include, exclude, extra)?; Ok(output_dict.into_py(py)) } @@ -271,46 +364,23 @@ impl TypeSerializer for GeneralFieldsSerializer { // If there is already a model registered (from a dataclass, BaseModel) // then do not touch it // If 
there is no model, we (a TypedDict) are the model - let td_extra = Extra { - model: extra.model.map_or_else(|| Some(value), Some), - ..*extra - }; + let model = extra.model.map_or_else(|| Some(value), Some); + let td_extra = Extra { model, ..*extra }; let expected_len = match self.mode { - FieldsMode::TypedDictAllow => main_dict.len() + option_length!(self.computed_fields), - _ => self.fields.len() + option_length!(extra_dict) + option_length!(self.computed_fields), + FieldsMode::TypedDictAllow => main_dict.len() + self.computed_field_count(), + _ => self.fields.len() + option_length!(extra_dict) + self.computed_field_count(), }; // NOTE! As above, we maintain the order of the input dict assuming that's right // we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used - let mut map = serializer.serialize_map(Some(expected_len))?; - - for (key, value) in main_dict { - if extra.exclude_none && value.is_none() { - continue; - } - let key_str = key_str(key).map_err(py_err_se_err)?; - let extra = Extra { - field_name: Some(key_str), - ..td_extra - }; + let mut map = self.main_serde_serialize( + main_dict.iter().map(Ok), + expected_len, + serializer, + include, + exclude, + td_extra, + )?; - let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = filter { - if let Some(field) = self.fields.get(key_str) { - if let Some(ref serializer) = field.serializer { - if !exclude_default(value, &extra, serializer).map_err(py_err_se_err)? 
{ - let s = PydanticSerializer::new(value, serializer, next_include, next_exclude, &extra); - let output_key = field.get_key_json(key_str, &extra); - map.serialize_entry(&output_key, &s)?; - } - } - } else if self.mode == FieldsMode::TypedDictAllow { - let output_key = infer_json_key(key, &extra).map_err(py_err_se_err)?; - let s = SerializeInfer::new(value, next_include, next_exclude, &extra); - map.serialize_entry(&output_key, &s)?; - } - // no error case here since unions (which need the error case) use `to_python(..., mode='json')` - } - } // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { for (key, value) in extra_dict { @@ -319,17 +389,14 @@ impl TypeSerializer for GeneralFieldsSerializer { } let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; if let Some((next_include, next_exclude)) = filter { - let output_key = infer_json_key(key, &td_extra).map_err(py_err_se_err)?; - let s = SerializeInfer::new(value, next_include, next_exclude, &td_extra); + let output_key = infer_json_key(key, extra).map_err(py_err_se_err)?; + let s = SerializeInfer::new(value, next_include, next_exclude, extra); map.serialize_entry(&output_key, &s)?; } } } - if let Some(ref computed_fields) = self.computed_fields { - if let Some(model) = td_extra.model { - computed_fields.serde_serialize::(model, &mut map, &self.filter, include, exclude, &td_extra)?; - } - } + + self.add_computed_fields_json::(model, &mut map, include, exclude, extra)?; map.end() } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index e39ca38f8..5ddf77597 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -10,19 +10,17 @@ use pyo3::types::{ use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; use crate::input::{EitherTimedelta, Int}; -use crate::serializers::config::InfNanMode; -use crate::serializers::errors::SERIALIZATION_ERR_MARKER; -use 
crate::serializers::filter::SchemaFilter; -use crate::serializers::shared::{PydanticSerializer, TypeSerializer}; -use crate::serializers::SchemaSerializer; use crate::tools::{extract_i64, py_err, safe_repr}; use crate::url::{PyMultiHostUrl, PyUrl}; +use super::config::InfNanMode; +use super::errors::SERIALIZATION_ERR_MARKER; use super::errors::{py_err_se_err, PydanticSerializationError}; use super::extra::{Extra, SerMode}; -use super::filter::AnyFilter; +use super::filter::{AnyFilter, SchemaFilter}; use super::ob_type::ObType; -use super::shared::dataclass_to_dict; +use super::shared::{any_dataclass_iter, PydanticSerializer, TypeSerializer}; +use super::SchemaSerializer; pub(crate) fn infer_to_python( value: &PyAny, @@ -83,22 +81,6 @@ pub(crate) fn infer_to_python_known( }}; } - let serialize_dict = |dict: &PyDict| { - let new_dict = PyDict::new(py); - let filter = AnyFilter::new(); - - for (k, v) in dict { - let op_next = filter.key_filter(k, include, exclude)?; - if let Some((next_include, next_exclude)) = op_next { - let k_str = infer_json_key(k, extra)?; - let k = PyString::new(py, &k_str); - let v = infer_to_python(v, next_include, next_exclude, extra)?; - new_dict.set_item(k, v)?; - } - } - Ok::(new_dict.into_py(py)) - }; - let serialize_with_serializer = || { let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?; let serializer: PyRef = py_serializer.extract()?; @@ -168,7 +150,12 @@ pub(crate) fn infer_to_python_known( let elements = serialize_seq!(PyFrozenSet); PyList::new(py, elements).into_py(py) } - ObType::Dict => serialize_dict(value.downcast()?)?, + ObType::Dict => { + let dict: &PyDict = value.downcast()?; + serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, |k| { + Ok(PyString::new(py, &infer_json_key(k, extra)?)) + })? 
+ } ObType::Datetime => { let py_dt: &PyDateTime = value.downcast()?; let iso_dt = super::type_serializers::datetime_etc::datetime_to_string(py_dt)?; @@ -205,7 +192,11 @@ pub(crate) fn infer_to_python_known( uuid.into_py(py) } ObType::PydanticSerializable => serialize_with_serializer()?, - ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?, + ObType::Dataclass => { + serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, |k| { + Ok(PyString::new(py, &infer_json_key(k, extra)?)) + })? + } ObType::Enum => { let v = value.getattr(intern!(py, "value"))?; infer_to_python(v, include, exclude, extra)?.into_py(py) @@ -256,22 +247,11 @@ pub(crate) fn infer_to_python_known( PyFrozenSet::new(py, &elements)?.into_py(py) } ObType::Dict => { - // different logic for keys from above let dict: &PyDict = value.downcast()?; - let new_dict = PyDict::new(py); - let filter = AnyFilter::new(); - - for (k, v) in dict { - let op_next = filter.key_filter(k, include, exclude)?; - if let Some((next_include, next_exclude)) = op_next { - let v = infer_to_python(v, next_include, next_exclude, extra)?; - new_dict.set_item(k, v)?; - } - } - new_dict.into_py(py) + serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, Ok)? } ObType::PydanticSerializable => serialize_with_serializer()?, - ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?, + ObType::Dataclass => serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, Ok)?, ObType::Generator => { let iter = super::type_serializers::generator::SerializationIterator::new( value.downcast()?, @@ -405,23 +385,6 @@ pub(crate) fn infer_serialize_known( }}; } - macro_rules! 
serialize_dict { - ($py_dict:expr) => {{ - let mut map = serializer.serialize_map(Some($py_dict.len()))?; - let filter = AnyFilter::new(); - - for (key, value) in $py_dict { - let op_next = filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let key = infer_json_key(key, extra).map_err(py_err_se_err)?; - let value_serializer = SerializeInfer::new(value, next_include, next_exclude, extra); - map.serialize_entry(&key, &value_serializer)?; - } - } - map.end() - }}; - } - let ser_result = match ob_type { ObType::None => serializer.serialize_none(), ObType::Int | ObType::IntSubclass => serialize!(Int), @@ -445,7 +408,10 @@ pub(crate) fn infer_serialize_known( .bytes_mode .serialize_bytes(unsafe { py_byte_array.as_bytes() }, serializer) } - ObType::Dict => serialize_dict!(value.downcast::().map_err(py_err_se_err)?), + ObType::Dict => { + let dict = value.downcast::().map_err(py_err_se_err)?; + serialize_pairs_json(dict.iter().map(Ok), dict.len(), serializer, include, exclude, extra) + } ObType::List => serialize_seq_filter!(PyList), ObType::Tuple => serialize_seq_filter!(PyTuple), ObType::Set => serialize_seq!(PySet), @@ -503,7 +469,10 @@ pub(crate) fn infer_serialize_known( PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); pydantic_serializer.serialize(serializer) } - ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?), + ObType::Dataclass => { + let (pairs_iter, fields_dict) = any_dataclass_iter(value).map_err(py_err_se_err)?; + serialize_pairs_json(pairs_iter, fields_dict.len(), serializer, include, exclude, extra) + } ObType::Uuid => { let py_uuid: &PyAny = value.downcast().map_err(py_err_se_err)?; let uuid = super::type_serializers::uuid::uuid_to_string(py_uuid).map_err(py_err_se_err)?; @@ -672,3 +641,50 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: ObType, key: &'py PyAny, extra: } } } + +fn 
serialize_pairs_python<'py>( + py: Python, + pairs_iter: impl Iterator>, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + key_transform: impl Fn(&'py PyAny) -> PyResult<&'py PyAny>, +) -> PyResult { + let new_dict = PyDict::new(py); + let filter = AnyFilter::new(); + + for result in pairs_iter { + let (k, v) = result?; + let op_next = filter.key_filter(k, include, exclude)?; + if let Some((next_include, next_exclude)) = op_next { + let k = key_transform(k)?; + let v = infer_to_python(v, next_include, next_exclude, extra)?; + new_dict.set_item(k, v)?; + } + } + Ok(new_dict.into_py(py)) +} + +fn serialize_pairs_json<'py, S: Serializer>( + pairs_iter: impl Iterator>, + iter_size: usize, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, +) -> Result { + let mut map = serializer.serialize_map(Some(iter_size))?; + let filter = AnyFilter::new(); + + for result in pairs_iter { + let (key, value) = result.map_err(py_err_se_err)?; + + let op_next = filter.key_filter(key, include, exclude).map_err(py_err_se_err)?; + if let Some((next_include, next_exclude)) = op_next { + let key = infer_json_key(key, extra).map_err(py_err_se_err)?; + let value_serializer = SerializeInfer::new(value, next_include, next_exclude, extra); + map.serialize_entry(&key, &value_serializer)?; + } + } + map.end() +} diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 11aac037d..7cfe6ce6e 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -364,29 +364,33 @@ pub(crate) fn to_json_bytes( Ok(bytes) } +pub(super) fn any_dataclass_iter<'py>( + dataclass: &'py PyAny, +) -> PyResult<(impl Iterator> + 'py, &PyDict)> { + let py = dataclass.py(); + let fields: &PyDict = dataclass.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; + let field_type_marker = get_field_marker(py)?; + + let next = move |(field_name, field): (&'py PyAny, &'py PyAny)| -> PyResult> { + let field_type = 
field.getattr(intern!(py, "_field_type"))?; + if field_type.is(field_type_marker) { + let field_name: &PyString = field_name.downcast()?; + let value = dataclass.getattr(field_name)?; + Ok(Some((field_name, value))) + } else { + Ok(None) + } + }; + + Ok((fields.iter().filter_map(move |field| next(field).transpose()), fields)) +} + static DC_FIELD_MARKER: GILOnceCell = GILOnceCell::new(); /// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)` -pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> { +fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> { let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || { - let field_ = py.import("dataclasses")?.getattr("_FIELD")?; - Ok::(field_.into_py(py)) + py.import("dataclasses")?.getattr("_FIELD").map(|f| f.into_py(py)) })?; Ok(field_type_marker_obj.as_ref(py)) } - -pub(super) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> { - let py = dc.py(); - let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; - let dict = PyDict::new(py); - - let field_type_marker = get_field_marker(py)?; - for (field_name, field) in dc_fields { - let field_type = field.getattr(intern!(py, "_field_type"))?; - if field_type.is(field_type_marker) { - let field_name: &PyString = field_name.downcast()?; - dict.set_item(field_name, dc.getattr(field_name)?)?; - } - } - Ok(dict) -} diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index a82643186..93548bcba 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyType}; use std::borrow::Cow; use ahash::AHashMap; +use serde::ser::SerializeMap; use crate::build_tools::{py_schema_error_type, ExtraBehavior}; use crate::definitions::DefinitionsBuilder; @@ -131,16 +132,30 @@ impl TypeSerializer for 
DataclassSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - let extra = Extra { + let dc_extra = Extra { model: Some(value), ..*extra }; - if self.allow_value(value, &extra)? { - let inner_value = self.get_inner_value(value)?; - self.serializer.to_python(inner_value, include, exclude, &extra) + if self.allow_value(value, extra)? { + let py = value.py(); + if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { + let output_dict = fields_serializer.main_to_python( + py, + known_dataclass_iter(&self.fields, value), + include, + exclude, + dc_extra, + )?; + + fields_serializer.add_computed_fields_python(Some(value), output_dict, include, exclude, extra)?; + Ok(output_dict.into_py(py)) + } else { + let inner_value = self.get_inner_value(value)?; + self.serializer.to_python(inner_value, include, exclude, &dc_extra) + } } else { - extra.warnings.on_fallback_py(self.get_name(), value, &extra)?; - infer_to_python(value, include, exclude, &extra) + extra.warnings.on_fallback_py(self.get_name(), value, &dc_extra)?; + infer_to_python(value, include, exclude, &dc_extra) } } @@ -161,17 +176,29 @@ impl TypeSerializer for DataclassSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let extra = Extra { - model: Some(value), - ..*extra - }; - if self.allow_value(value, &extra).map_err(py_err_se_err)? { - let inner_value = self.get_inner_value(value).map_err(py_err_se_err)?; - self.serializer - .serde_serialize(inner_value, serializer, include, exclude, &extra) + let model = Some(value); + let dc_extra = Extra { model, ..*extra }; + if self.allow_value(value, extra).map_err(py_err_se_err)? 
{ + if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer { + let expected_len = self.fields.len() + fields_serializer.computed_field_count(); + let mut map = fields_serializer.main_serde_serialize( + known_dataclass_iter(&self.fields, value), + expected_len, + serializer, + include, + exclude, + dc_extra, + )?; + fields_serializer.add_computed_fields_json::(model, &mut map, include, exclude, extra)?; + map.end() + } else { + let inner_value = self.get_inner_value(value).map_err(py_err_se_err)?; + self.serializer + .serde_serialize(inner_value, serializer, include, exclude, extra) + } } else { - extra.warnings.on_fallback_ser::(self.get_name(), value, &extra)?; - infer_serialize(value, serializer, include, exclude, &extra) + extra.warnings.on_fallback_ser::(self.get_name(), value, extra)?; + infer_serialize(value, serializer, include, exclude, extra) } } @@ -183,3 +210,18 @@ impl TypeSerializer for DataclassSerializer { true } } + +fn known_dataclass_iter<'a, 'py>( + fields: &'a [Py], + dataclass: &'py PyAny, +) -> impl Iterator> + 'a +where + 'py: 'a, +{ + let py = dataclass.py(); + fields.iter().map(move |field| { + let field_ref = field.clone_ref(py).into_ref(py); + let value = dataclass.getattr(field_ref)?; + Ok((field_ref as &PyAny, value)) + }) +} diff --git a/tests/benchmarks/test_serialization_micro.py b/tests/benchmarks/test_serialization_micro.py index 9200fc24b..96170b5eb 100644 --- a/tests/benchmarks/test_serialization_micro.py +++ b/tests/benchmarks/test_serialization_micro.py @@ -1,4 +1,5 @@ import json +from dataclasses import dataclass from datetime import date, datetime from uuid import UUID @@ -409,3 +410,48 @@ def test_ser_list_of_lists(benchmark): data = [[i + j for j in range(10)] for i in range(1000)] benchmark(s.to_json, data) + + +@dataclass +class Foo: + a: str + b: bytes + c: int + d: float + + +dataclass_schema = core_schema.dataclass_schema( + Foo, + core_schema.dataclass_args_schema( + 'Foo', + [ + 
core_schema.dataclass_field(name='a', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema()), + core_schema.dataclass_field(name='c', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='d', schema=core_schema.float_schema()), + ], + ), + ['a', 'b', 'c', 'd'], +) + + +@pytest.mark.benchmark(group='dataclass-ser') +def test_dataclass_serialization_python(benchmark): + s = SchemaSerializer(dataclass_schema) + dc = Foo(a='hello', b=b'more', c=123, d=1.23) + assert s.to_python(dc) == {'a': 'hello', 'b': b'more', 'c': 123, 'd': 1.23} + benchmark(s.to_python, dc) + + +@pytest.mark.benchmark(group='dataclass-ser') +def test_dataclass_serialization_json(benchmark): + s = SchemaSerializer(dataclass_schema) + dc = Foo(a='hello', b=b'more', c=123, d=1.23) + assert s.to_python(dc) == {'a': 'hello', 'b': b'more', 'c': 123, 'd': 1.23} + benchmark(s.to_json, dc) + + +@pytest.mark.benchmark(group='dataclass-ser') +def test_dataclass_to_json(benchmark): + dc = Foo(a='hello', b=b'more', c=123, d=1.23) + benchmark(to_json, dc) From 29c541917ad31595a938bf4e96654e29b4f50f7f Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Wed, 17 Jan 2024 07:27:14 -0700 Subject: [PATCH 188/550] Add support for dataclass fields init (#1163) Co-authored-by: sydney-runkle Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Co-authored-by: David Hewitt --- python/pydantic_core/core_schema.py | 4 + src/validators/dataclass.rs | 19 +++++ tests/validators/test_dataclasses.py | 121 +++++++++++++++++++++++++++ 3 files changed, 144 insertions(+) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 44f58c48a..31bf48782 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -2985,6 +2985,7 @@ class DataclassField(TypedDict, total=False): name: Required[str] schema: Required[CoreSchema] 
kw_only: bool # default: True + init: bool # default: True init_only: bool # default: False frozen: bool # default: False validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] @@ -2998,6 +2999,7 @@ def dataclass_field( schema: CoreSchema, *, kw_only: bool | None = None, + init: bool | None = None, init_only: bool | None = None, validation_alias: str | list[str | int] | list[list[str | int]] | None = None, serialization_alias: str | None = None, @@ -3023,6 +3025,7 @@ def dataclass_field( name: The name to use for the argument parameter schema: The schema to use for the argument parameter kw_only: Whether the field can be set with a positional argument as well as a keyword argument + init: Whether the field should be validated during initialization init_only: Whether the field should be omitted from `__dict__` and passed to `__post_init__` validation_alias: The alias(es) to use to find the field in the validation data serialization_alias: The alias to use as a key when serializing @@ -3035,6 +3038,7 @@ def dataclass_field( name=name, schema=schema, kw_only=kw_only, + init=init, init_only=init_only, validation_alias=validation_alias, serialization_alias=serialization_alias, diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 986556b66..3161541f0 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -26,6 +26,7 @@ struct Field { kw_only: bool, name: String, py_name: Py, + init: bool, init_only: bool, lookup_key: LookupKey, validator: CombinedValidator, @@ -107,6 +108,7 @@ impl BuildValidator for DataclassArgsValidator { py_name: py_name.into(), lookup_key, validator, + init: field.get_as(intern!(py, "init"))?.unwrap_or(true), init_only: field.get_as(intern!(py, "init_only"))?.unwrap_or(false), frozen: field.get_as::(intern!(py, "frozen"))?.unwrap_or(false), }); @@ -176,6 +178,23 @@ impl Validator for DataclassArgsValidator { ($args:ident, $get_method:ident, $get_macro:ident, $slice_macro:ident) 
=> {{ // go through fields getting the value from args or kwargs and validating it for (index, field) in self.fields.iter().enumerate() { + if (!field.init) { + match field.validator.default_value(py, Some(field.name.as_str()), state) { + Ok(Some(value)) => { + // Default value exists, and passed validation if required + set_item!(field, value); + }, + Ok(None) | Err(ValError::Omit) => continue, + // Note: this will always use the field name even if there is an alias + // However, we don't mind so much because this error can only happen if the + // default value fails validation, which is arguably a developer error. + // We could try to "fix" this in the future if desired. + Err(ValError::LineErrors(line_errors)) => errors.extend(line_errors), + Err(err) => return Err(err), + }; + continue; + }; + let mut pos_value = None; if let Some(args) = $args.args { if !field.kw_only { diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 94bf7fd65..a9b367008 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1592,3 +1592,124 @@ def _wrap_validator(cls, v, validator, info): gc.collect() assert ref() is None + + +init_test_cases = [ + ({'a': 'hello', 'b': 'bye'}, 'ignore', {'a': 'hello', 'b': 'HELLO'}), + ({'a': 'hello'}, 'ignore', {'a': 'hello', 'b': 'HELLO'}), + # note, for the case below, we don't actually support this case in Pydantic + # it's disallowed in Pydantic to have a model with extra='allow' and a field + # with init=False, so this case isn't really possible at the momment + # however, no conflict arises here because we don't pass in the value for b + # to __init__ + ({'a': 'hello'}, 'allow', {'a': 'hello', 'b': 'HELLO'}), + ( + {'a': 'hello', 'b': 'bye'}, + 'forbid', + Err( + 'Unexpected keyword argument', + errors=[ + { + 'type': 'unexpected_keyword_argument', + 'loc': ('b',), + 'msg': 'Unexpected keyword argument', + 'input': 'bye', + } + ], + ), + ), + ({'a': 'hello'}, 
'forbid', {'a': 'hello', 'b': 'HELLO'}), +] + + +@pytest.mark.parametrize( + 'input_value,extra_behavior,expected', + [ + *init_test_cases, + # special case - when init=False, extra='allow', and the value is provided + # currently, it's disallowed in Pydantic to have a model with extra='allow' + # and a field with init=False, so this case isn't really possible at the momment + # TODO: open to changing this behavior, and changes won't be significantly breaking + # because we currently don't support this case + ({'a': 'hello', 'b': 'bye'}, 'allow', {'a': 'hello', 'b': 'HELLO'}), + ], +) +def test_dataclass_args_init(input_value, extra_behavior, expected): + @dataclasses.dataclass + class Foo: + a: str + b: str + + def __post_init__(self): + self.b = self.a.upper() + + schema = core_schema.dataclass_schema( + Foo, + core_schema.dataclass_args_schema( + 'Foo', + [ + core_schema.dataclass_field(name='a', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='b', schema=core_schema.str_schema(), init=False), + ], + extra_behavior=extra_behavior, + ), + ['a', 'b'], + post_init=True, + ) + + v = SchemaValidator(schema) + + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)) as exc_info: + v.validate_python(input_value) + + if expected.errors is not None: + assert exc_info.value.errors(include_url=False) == expected.errors + else: + assert dataclasses.asdict(v.validate_python(input_value)) == expected + + +@pytest.mark.parametrize( + 'input_value,extra_behavior,expected', + [ + *init_test_cases, + # special case - allow override of default, even when init=False, if extra='allow' + # TODO: we haven't really decided if this should be allowed or not + # currently, it's disallowed in Pydantic to have a model with extra='allow' + # and a field with init=False, so this case isn't really possible at the momment + ({'a': 'hello', 'b': 'bye'}, 'allow', {'a': 'hello', 'b': 'bye'}), + ], +) +def 
test_dataclass_args_init_with_default(input_value, extra_behavior, expected): + @dataclasses.dataclass + class Foo: + a: str + b: str + + schema = core_schema.dataclass_schema( + Foo, + core_schema.dataclass_args_schema( + 'Foo', + [ + core_schema.dataclass_field(name='a', schema=core_schema.str_schema()), + core_schema.dataclass_field( + name='b', + schema=core_schema.with_default_schema(schema=core_schema.str_schema(), default='HELLO'), + init=False, + ), + ], + extra_behavior=extra_behavior, + ), + ['a', 'b'], + ) + + v = SchemaValidator(schema) + + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=re.escape(expected.message)) as exc_info: + v.validate_python(input_value) + + if expected.errors is not None: + assert exc_info.value.errors(include_url=False) == expected.errors + else: + assert dataclasses.asdict(v.validate_python(input_value)) == expected From 4da7192ffc104cd6c424d102bc1c9c5bfdad543e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 17 Jan 2024 20:02:10 +0000 Subject: [PATCH 189/550] uprev to jiter v0.0.6, uprev pydantic-core (#1165) --- Cargo.lock | 21 ++++++++++++++++----- Cargo.toml | 4 ++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 34b46ee65..a2993621f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,6 +24,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + [[package]] name = "autocfg" version = "1.1.0" @@ -94,9 +100,13 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +dependencies = [ + "ahash", + "allocator-api2", +] 
[[package]] name = "heck" @@ -138,11 +148,12 @@ checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" [[package]] name = "jiter" -version = "0.0.5" +version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e184598fea113663dd78e33a24ad3a1e7ba8ceedf71effb7406b3f2eccb63ed1" +checksum = "87db066a99f69382be06d02313f8ce989996b53a04a8a70cfd1a6483a56227f7" dependencies = [ "ahash", + "hashbrown", "lexical-core", "num-bigint", "num-traits", @@ -321,7 +332,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.15.0" +version = "2.16.0" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index 283e364f9..a878fcfd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.15.0" +version = "2.16.0" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" @@ -43,7 +43,7 @@ base64 = "0.21.7" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.6.1" -jiter = {version = "0.0.5", features = ["python"]} +jiter = {version = "0.0.6", features = ["python"]} [lib] name = "_pydantic_core" From 7a5f8e6b9a9951dd674a9841e1b463aec7a7a281 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Fri, 19 Jan 2024 12:10:36 +0000 Subject: [PATCH 190/550] Ensure recursion guard is always used as a stack (#1166) --- src/recursion_guard.rs | 63 +++++++++++++++++-- src/serializers/extra.rs | 56 +++++++++-------- src/serializers/infer.rs | 29 ++++----- src/serializers/mod.rs | 8 +-- .../type_serializers/definitions.rs | 19 +++--- src/validators/definitions.rs | 33 +++------- src/validators/generator.rs | 4 +- src/validators/mod.rs | 10 +-- src/validators/validation_state.rs | 12 +++- 9 files changed, 137 insertions(+), 97 deletions(-) diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index fe5b1bcdd..d3302c0e9 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -12,8 +12,59 @@ type RecursionKey = ( /// This 
is used to avoid cyclic references in input data causing recursive validation and a nasty segmentation fault. /// It's used in `validators/definition` to detect when a reference is reused within itself. +pub(crate) struct RecursionGuard<'a, S: ContainsRecursionState> { + state: &'a mut S, + obj_id: usize, + node_id: usize, +} + +pub(crate) enum RecursionError { + /// Cyclic reference detected + Cyclic, + /// Recursion limit exceeded + Depth, +} + +impl RecursionGuard<'_, S> { + /// Creates a recursion guard for the given object and node id. + /// + /// When dropped, this will release the recursion for the given object and node id. + pub fn new(state: &'_ mut S, obj_id: usize, node_id: usize) -> Result, RecursionError> { + state.access_recursion_state(|state| { + if !state.insert(obj_id, node_id) { + return Err(RecursionError::Cyclic); + } + if state.incr_depth() { + return Err(RecursionError::Depth); + } + Ok(()) + })?; + Ok(RecursionGuard { state, obj_id, node_id }) + } + + /// Retrieves the underlying state for further use. + pub fn state(&mut self) -> &mut S { + self.state + } +} + +impl Drop for RecursionGuard<'_, S> { + fn drop(&mut self) { + self.state.access_recursion_state(|state| { + state.decr_depth(); + state.remove(self.obj_id, self.node_id); + }); + } +} + +/// This trait is used to retrieve the recursion state from some other type +pub(crate) trait ContainsRecursionState { + fn access_recursion_state(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R; +} + +/// State for the RecursionGuard. Can also be used directly to increase / decrease depth. 
#[derive(Debug, Clone, Default)] -pub struct RecursionGuard { +pub struct RecursionState { ids: RecursionStack, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators @@ -31,11 +82,11 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi 255 }; -impl RecursionGuard { +impl RecursionState { // insert a new value // * return `false` if the stack already had it in it // * return `true` if the stack didn't have it in it and it was inserted - pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool { + fn insert(&mut self, obj_id: usize, node_id: usize) -> bool { self.ids.insert((obj_id, node_id)) } @@ -68,7 +119,7 @@ impl RecursionGuard { self.depth = self.depth.saturating_sub(1); } - pub fn remove(&mut self, obj_id: usize, node_id: usize) { + fn remove(&mut self, obj_id: usize, node_id: usize) { self.ids.remove(&(obj_id, node_id)); } } @@ -98,7 +149,7 @@ impl RecursionStack { // insert a new value // * return `false` if the stack already had it in it // * return `true` if the stack didn't have it in it and it was inserted - pub fn insert(&mut self, v: RecursionKey) -> bool { + fn insert(&mut self, v: RecursionKey) -> bool { match self { Self::Array { data, len } => { if *len < ARRAY_SIZE { @@ -129,7 +180,7 @@ impl RecursionStack { } } - pub fn remove(&mut self, v: &RecursionKey) { + fn remove(&mut self, v: &RecursionKey) { match self { Self::Array { data, len } => { *len = len.checked_sub(1).expect("remove from empty recursion guard"); diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index b3978613a..8d598d46b 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -10,20 +10,23 @@ use serde::ser::Error; use super::config::SerializationConfig; use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER}; use super::ob_type::ObTypeLookup; +use 
crate::recursion_guard::ContainsRecursionState; +use crate::recursion_guard::RecursionError; use crate::recursion_guard::RecursionGuard; +use crate::recursion_guard::RecursionState; /// this is ugly, would be much better if extra could be stored in `SerializationState` /// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work pub(crate) struct SerializationState { warnings: CollectWarnings, - rec_guard: SerRecursionGuard, + rec_guard: SerRecursionState, config: SerializationConfig, } impl SerializationState { pub fn new(timedelta_mode: &str, bytes_mode: &str, inf_nan_mode: &str) -> PyResult { let warnings = CollectWarnings::new(false); - let rec_guard = SerRecursionGuard::default(); + let rec_guard = SerRecursionState::default(); let config = SerializationConfig::from_args(timedelta_mode, bytes_mode, inf_nan_mode)?; Ok(Self { warnings, @@ -77,7 +80,7 @@ pub(crate) struct Extra<'a> { pub exclude_none: bool, pub round_trip: bool, pub config: &'a SerializationConfig, - pub rec_guard: &'a SerRecursionGuard, + pub rec_guard: &'a SerRecursionState, // the next two are used for union logic pub check: SerCheck, // data representing the current model field @@ -101,7 +104,7 @@ impl<'a> Extra<'a> { exclude_none: bool, round_trip: bool, config: &'a SerializationConfig, - rec_guard: &'a SerRecursionGuard, + rec_guard: &'a SerRecursionState, serialize_unknown: bool, fallback: Option<&'a PyAny>, ) -> Self { @@ -124,6 +127,22 @@ impl<'a> Extra<'a> { } } + pub fn recursion_guard<'x, 'y>( + // TODO: this double reference is a bit if a hack, but it's necessary because the recursion + // guard is not passed around with &mut reference + // + // See how validation has &mut ValidationState passed around; we should aim to refactor + // to match that. 
+ self: &'x mut &'y Self, + value: &PyAny, + def_ref_id: usize, + ) -> PyResult> { + RecursionGuard::new(self, value.as_ptr() as usize, def_ref_id).map_err(|e| match e { + RecursionError::Depth => PyValueError::new_err("Circular reference detected (depth exceeded)"), + RecursionError::Cyclic => PyValueError::new_err("Circular reference detected (id repeated)"), + }) + } + pub fn serialize_infer<'py>(&'py self, value: &'py PyAny) -> super::infer::SerializeInfer<'py> { super::infer::SerializeInfer::new(value, None, None, self) } @@ -157,7 +176,7 @@ pub(crate) struct ExtraOwned { exclude_none: bool, round_trip: bool, config: SerializationConfig, - rec_guard: SerRecursionGuard, + rec_guard: SerRecursionState, check: SerCheck, model: Option, field_name: Option, @@ -340,29 +359,12 @@ impl CollectWarnings { #[derive(Default, Clone)] #[cfg_attr(debug_assertions, derive(Debug))] -pub struct SerRecursionGuard { - guard: RefCell, +pub struct SerRecursionState { + guard: RefCell, } -impl SerRecursionGuard { - pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult { - let id = value.as_ptr() as usize; - let mut guard = self.guard.borrow_mut(); - - if guard.insert(id, def_ref_id) { - if guard.incr_depth() { - Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) - } else { - Ok(id) - } - } else { - Err(PyValueError::new_err("Circular reference detected (id repeated)")) - } - } - - pub fn pop(&self, id: usize, def_ref_id: usize) { - let mut guard = self.guard.borrow_mut(); - guard.decr_depth(); - guard.remove(id, def_ref_id); +impl ContainsRecursionState for &'_ Extra<'_> { + fn access_recursion_state(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R { + f(&mut self.rec_guard.guard.borrow_mut()) } } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 5ddf77597..487d0d091 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -40,19 +40,22 @@ pub(crate) fn infer_to_python_known( value: &PyAny, include: 
Option<&PyAny>, exclude: Option<&PyAny>, - extra: &Extra, + mut extra: &Extra, ) -> PyResult { let py = value.py(); - let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) { - Ok(id) => id, + + let mode = extra.mode; + let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) { + Ok(v) => v, Err(e) => { - return match extra.mode { + return match mode { SerMode::Json => Err(e), // if recursion is detected by we're serializing to python, we just return the value _ => Ok(value.into_py(py)), }; } }; + let extra = guard.state(); macro_rules! serialize_seq { ($t:ty) => { @@ -220,7 +223,6 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } else if extra.serialize_unknown { serialize_unknown(value).into_py(py) @@ -267,7 +269,6 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } value.into_py(py) @@ -275,7 +276,6 @@ pub(crate) fn infer_to_python_known( _ => value.into_py(py), }, }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); Ok(value) } @@ -332,18 +332,21 @@ pub(crate) fn infer_serialize_known( serializer: S, include: Option<&PyAny>, exclude: Option<&PyAny>, - extra: &Extra, + mut extra: &Extra, ) -> Result { - let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) { + let extra_serialize_unknown = extra.serialize_unknown; + let mut guard = match extra.recursion_guard(value, INFER_DEF_REF_ID) { Ok(v) => v, Err(e) => { - return if extra.serialize_unknown { + return if extra_serialize_unknown { serializer.serialize_str("...") } else { - Err(e) - } + Err(py_err_se_err(e)) + }; } }; + 
let extra = guard.state(); + macro_rules! serialize { ($t:ty) => { match value.extract::<$t>() { @@ -506,7 +509,6 @@ pub(crate) fn infer_serialize_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,)).map_err(py_err_se_err)?; let next_result = infer_serialize(next_value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } else if extra.serialize_unknown { serializer.serialize_str(&serialize_unknown(value)) @@ -520,7 +522,6 @@ pub(crate) fn infer_serialize_known( } } }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); ser_result } diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 8159691cb..7d9c5347c 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -10,7 +10,7 @@ use crate::py_gc::PyGcTraverse; use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; -use extra::{CollectWarnings, SerRecursionGuard}; +use extra::{CollectWarnings, SerRecursionState}; pub(crate) use extra::{Extra, SerMode, SerializationState}; pub use shared::CombinedSerializer; use shared::{to_json_bytes, BuildSerializer, TypeSerializer}; @@ -52,7 +52,7 @@ impl SchemaSerializer { exclude_defaults: bool, exclude_none: bool, round_trip: bool, - rec_guard: &'a SerRecursionGuard, + rec_guard: &'a SerRecursionState, serialize_unknown: bool, fallback: Option<&'a PyAny>, ) -> Extra<'b> { @@ -113,7 +113,7 @@ impl SchemaSerializer { ) -> PyResult { let mode: SerMode = mode.into(); let warnings = CollectWarnings::new(warnings); - let rec_guard = SerRecursionGuard::default(); + let rec_guard = SerRecursionState::default(); let extra = self.build_extra( py, &mode, @@ -152,7 +152,7 @@ impl SchemaSerializer { fallback: Option<&PyAny>, ) -> PyResult { let warnings = CollectWarnings::new(warnings); - let rec_guard = SerRecursionGuard::default(); + let rec_guard = SerRecursionState::default(); let extra = self.build_extra( py, 
&SerMode::Json, diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 99dae5bcd..2f98a94e0 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -66,14 +66,12 @@ impl TypeSerializer for DefinitionRefSerializer { value: &PyAny, include: Option<&PyAny>, exclude: Option<&PyAny>, - extra: &Extra, + mut extra: &Extra, ) -> PyResult { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let value_id = extra.rec_guard.add(value, self.definition.id())?; - let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); - r + let mut guard = extra.recursion_guard(value, self.definition.id())?; + comb_serializer.to_python(value, include, exclude, guard.state()) }) } @@ -87,17 +85,14 @@ impl TypeSerializer for DefinitionRefSerializer { serializer: S, include: Option<&PyAny>, exclude: Option<&PyAny>, - extra: &Extra, + mut extra: &Extra, ) -> Result { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let value_id = extra - .rec_guard - .add(value, self.definition.id()) + let mut guard = extra + .recursion_guard(value, self.definition.id()) .map_err(py_err_se_err)?; - let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); - r + comb_serializer.serde_serialize(value, serializer, include, exclude, guard.state()) }) } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index e8c67a690..e4bc270c2 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -6,6 +6,7 @@ use crate::definitions::DefinitionRef; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; use crate::input::Input; +use crate::recursion_guard::RecursionGuard; use crate::tools::SchemaDict; use super::{build_validator, 
BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; @@ -76,18 +77,10 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = input.identity() { - if state.recursion_guard.insert(id, self.definition.id()) { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); - } - let output = validator.validate(py, input, state); - state.recursion_guard.remove(id, self.definition.id()); - state.recursion_guard.decr_depth(); - output - } else { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) - } + let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); + }; + validator.validate(py, input, guard.state()) } else { validator.validate(py, input, state) } @@ -105,18 +98,10 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = obj.identity() { - if state.recursion_guard.insert(id, self.definition.id()) { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); - } - let output = validator.validate_assignment(py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.definition.id()); - state.recursion_guard.decr_depth(); - output - } else { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) - } + let Ok(mut guard) = RecursionGuard::new(state, id, self.definition.id()) else { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); + }; + validator.validate_assignment(py, obj, field_name, field_value, guard.state()) } else { 
validator.validate_assignment(py, obj, field_name, field_value, state) } diff --git a/src/validators/generator.rs b/src/validators/generator.rs index 3b5fedd97..a366da431 100644 --- a/src/validators/generator.rs +++ b/src/validators/generator.rs @@ -6,7 +6,7 @@ use pyo3::types::PyDict; use crate::errors::{ErrorType, LocItem, ValError, ValResult}; use crate::input::{GenericIterator, Input}; -use crate::recursion_guard::RecursionGuard; +use crate::recursion_guard::RecursionState; use crate::tools::SchemaDict; use crate::ValidationError; @@ -212,7 +212,7 @@ pub struct InternalValidator { from_attributes: Option, context: Option, self_instance: Option, - recursion_guard: RecursionGuard, + recursion_guard: RecursionState, pub(crate) exactness: Option, validation_mode: InputType, hide_input_in_errors: bool, diff --git a/src/validators/mod.rs b/src/validators/mod.rs index fc1584080..496da5c73 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -13,7 +13,7 @@ use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::errors::{LocItem, ValError, ValResult, ValidationError}; use crate::input::{Input, InputType, StringMapping}; use crate::py_gc::PyGcTraverse; -use crate::recursion_guard::RecursionGuard; +use crate::recursion_guard::RecursionState; use crate::tools::SchemaDict; mod any; @@ -263,7 +263,7 @@ impl SchemaValidator { self_instance: None, }; - let guard = &mut RecursionGuard::default(); + let guard = &mut RecursionState::default(); let mut state = ValidationState::new(extra, guard); self.validator .validate_assignment(py, obj, field_name, field_value, &mut state) @@ -280,7 +280,7 @@ impl SchemaValidator { context, self_instance: None, }; - let recursion_guard = &mut RecursionGuard::default(); + let recursion_guard = &mut RecursionState::default(); let mut state = ValidationState::new(extra, recursion_guard); let r = self.validator.default_value(py, None::, &mut state); match r { @@ -326,7 +326,7 @@ impl SchemaValidator { where 's: 'data, { - 
let mut recursion_guard = RecursionGuard::default(); + let mut recursion_guard = RecursionState::default(); let mut state = ValidationState::new( Extra::new(strict, from_attributes, context, self_instance, input_type), &mut recursion_guard, @@ -378,7 +378,7 @@ impl<'py> SelfValidator<'py> { } pub fn validate_schema(&self, py: Python<'py>, schema: &'py PyAny, strict: Option) -> PyResult<&'py PyAny> { - let mut recursion_guard = RecursionGuard::default(); + let mut recursion_guard = RecursionState::default(); let mut state = ValidationState::new( Extra::new(strict, None, None, None, InputType::Python), &mut recursion_guard, diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index aacd7d2af..4f241d768 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -1,4 +1,4 @@ -use crate::recursion_guard::RecursionGuard; +use crate::recursion_guard::{ContainsRecursionState, RecursionState}; use super::Extra; @@ -10,14 +10,14 @@ pub enum Exactness { } pub struct ValidationState<'a> { - pub recursion_guard: &'a mut RecursionGuard, + pub recursion_guard: &'a mut RecursionState, pub exactness: Option, // deliberately make Extra readonly extra: Extra<'a>, } impl<'a> ValidationState<'a> { - pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionGuard) -> Self { + pub fn new(extra: Extra<'a>, recursion_guard: &'a mut RecursionState) -> Self { Self { recursion_guard, // Don't care about exactness unless doing union validation exactness: None, @@ -84,6 +84,12 @@ impl<'a> ValidationState<'a> { } } +impl ContainsRecursionState for ValidationState<'_> { + fn access_recursion_state(&mut self, f: impl FnOnce(&mut RecursionState) -> R) -> R { + f(self.recursion_guard) + } +} + pub struct ValidationStateWithReboundExtra<'state, 'a> { state: &'state mut ValidationState<'a>, old_extra: Extra<'a>, From 4538190f0e7a47a99ca44351f744007b016511d4 Mon Sep 17 00:00:00 2001 From: Sydney Runkle 
<54324534+sydney-runkle@users.noreply.github.com> Date: Fri, 19 Jan 2024 06:46:11 -0600 Subject: [PATCH 191/550] Uprev core to 2.16.1 (#1167) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a2993621f..5d113589f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.16.0" +version = "2.16.1" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index a878fcfd4..c26c91419 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.16.0" +version = "2.16.1" edition = "2021" license = "MIT" homepage = "https://github.com/pydantic/pydantic-core" From c670b2bd729d2b91eca5b24b2fe3d2785e3e8cd2 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 22 Jan 2024 11:46:34 +0000 Subject: [PATCH 192/550] remove `AsLocItem` trait (#1169) --- src/errors/line_error.rs | 7 +++-- src/errors/location.rs | 16 +++++------ src/errors/mod.rs | 2 +- src/input/input_abstract.rs | 4 +-- src/input/input_json.rs | 40 ++++++++++----------------- src/input/input_python.rs | 50 +++++++++++++++++----------------- src/input/input_string.rs | 14 +++++----- src/input/return_enums.rs | 4 +-- src/lookup_key.rs | 4 +-- src/validators/arguments.rs | 12 ++++---- src/validators/call.rs | 2 +- src/validators/dataclass.rs | 14 ++++------ src/validators/dict.rs | 16 ++++++----- src/validators/function.rs | 6 ++-- src/validators/model_fields.rs | 12 ++++---- src/validators/tuple.rs | 6 ++-- src/validators/typed_dict.rs | 12 ++++---- src/validators/union.rs | 6 ++-- src/validators/with_default.rs | 2 +- 19 files changed, 108 insertions(+), 121 deletions(-) diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index f5fa11b70..2a48bfaf4 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -61,7 +61,8 @@ impl ValError { } /// helper function to call with_outer on line 
items if applicable - pub fn with_outer_location(self, loc_item: LocItem) -> Self { + pub fn with_outer_location(self, into_loc_item: impl Into) -> Self { + let loc_item = into_loc_item.into(); match self { Self::LineErrors(mut line_errors) => { for line_error in &mut line_errors { @@ -120,8 +121,8 @@ impl ValLineError { /// location is stored reversed so it's quicker to add "outer" items as that's what we always do /// hence `push` here instead of `insert` - pub fn with_outer_location(mut self, loc_item: LocItem) -> Self { - self.location.with_outer(loc_item); + pub fn with_outer_location(mut self, into_loc_item: impl Into) -> Self { + self.location.with_outer(into_loc_item.into()); self } diff --git a/src/errors/location.rs b/src/errors/location.rs index 8acc2a039..55bab0017 100644 --- a/src/errors/location.rs +++ b/src/errors/location.rs @@ -34,18 +34,18 @@ impl fmt::Display for LocItem { } } -// TODO rename to ToLocItem -pub trait AsLocItem { - // TODO rename to to_loc_item - fn as_loc_item(&self) -> LocItem; -} - impl From for LocItem { fn from(s: String) -> Self { Self::S(s) } } +impl From<&String> for LocItem { + fn from(s: &String) -> Self { + s.to_string().into() + } +} + impl From<&str> for LocItem { fn from(s: &str) -> Self { Self::S(s.to_string()) @@ -201,9 +201,9 @@ impl TryFrom> for Location { fn try_from(location: Option<&PyAny>) -> PyResult { if let Some(location) = location { let mut loc_vec: Vec = if let Ok(tuple) = location.downcast::() { - tuple.iter().map(AsLocItem::as_loc_item).collect() + tuple.iter().map(Into::into).collect() } else if let Ok(list) = location.downcast::() { - list.iter().map(AsLocItem::as_loc_item).collect() + list.iter().map(Into::into).collect() } else { return Err(PyTypeError::new_err( "Location must be a list or tuple of strings and ints", diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 131e54177..de6650527 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -7,7 +7,7 @@ mod validation_exception; mod 
value_exception; pub use self::line_error::{AsErrorValue, InputValue, ValError, ValLineError, ValResult}; -pub use self::location::{AsLocItem, LocItem}; +pub use self::location::LocItem; pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number}; pub use self::validation_exception::ValidationError; pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault}; diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 8229677b6..6ce1479f6 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -4,7 +4,7 @@ use pyo3::exceptions::PyValueError; use pyo3::types::{PyDict, PyType}; use pyo3::{intern, prelude::*}; -use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult}; +use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::py_err; use crate::{PyMultiHostUrl, PyUrl}; @@ -46,7 +46,7 @@ impl TryFrom<&str> for InputType { /// the convention is to either implement: /// * `strict_*` & `lax_*` if they have different behavior /// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same -pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized { +pub trait Input<'a>: fmt::Debug + ToPyObject + Into + Sized { fn as_error_value(&self) -> InputValue; fn identity(&self) -> Option { diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 195f9caba..cd4ac919c 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -6,7 +6,7 @@ use pyo3::types::{PyDict, PyString}; use speedate::MicrosecondsPrecisionOverflowBehavior; use strum::EnumMessage; -use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::validators::decimal::create_decimal; use super::datetime::{ @@ -21,9 +21,9 @@ use 
super::{ }; /// This is required but since JSON object keys are always strings, I don't think it can be called -impl AsLocItem for JsonValue { - fn as_loc_item(&self) -> LocItem { - match self { +impl From<&JsonValue> for LocItem { + fn from(json_value: &JsonValue) -> Self { + match json_value { JsonValue::Int(i) => (*i).into(), JsonValue::Str(s) => s.as_str().into(), v => format!("{v:?}").into(), @@ -31,9 +31,9 @@ impl AsLocItem for JsonValue { } } -impl AsLocItem for &JsonValue { - fn as_loc_item(&self) -> LocItem { - AsLocItem::as_loc_item(*self) +impl From for LocItem { + fn from(json_value: JsonValue) -> Self { + (&json_value).into() } } @@ -84,13 +84,6 @@ impl<'a> Input<'a> for JsonValue { } } - fn exact_str(&'a self) -> ValResult> { - match self { - JsonValue::Str(s) => Ok(s.as_str().into()), - _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), - } - } - fn validate_str( &'a self, strict: bool, @@ -144,6 +137,13 @@ impl<'a> Input<'a> for JsonValue { } } + fn exact_str(&'a self) -> ValResult> { + match self { + JsonValue::Str(s) => Ok(s.as_str().into()), + _ => Err(ValError::new(ErrorTypeDefaults::StringType, self)), + } + } + fn validate_float(&'a self, strict: bool) -> ValResult>> { match self { JsonValue::Float(f) => Ok(ValidationMatch::exact(EitherFloat::F64(*f))), @@ -319,18 +319,6 @@ impl BorrowInput for &'_ JsonValue { } } -impl AsLocItem for String { - fn as_loc_item(&self) -> LocItem { - self.to_string().into() - } -} - -impl AsLocItem for &String { - fn as_loc_item(&self) -> LocItem { - AsLocItem::as_loc_item(*self) - } -} - /// Required for JSON Object keys so the string can behave like an Input impl<'a> Input<'a> for String { fn as_error_value(&self) -> InputValue { diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 2eaaebe10..fd93820ef 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -12,7 +12,7 @@ use pyo3::{intern, PyTypeInfo}; use speedate::MicrosecondsPrecisionOverflowBehavior; -use 
crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::tools::{extract_i64, safe_repr}; use crate::validators::decimal::{create_decimal, get_decimal_type}; use crate::validators::Exactness; @@ -92,21 +92,21 @@ macro_rules! extract_dict_items { }; } -impl AsLocItem for PyAny { - fn as_loc_item(&self) -> LocItem { - if let Ok(py_str) = self.downcast::() { +impl From<&PyAny> for LocItem { + fn from(py_any: &PyAny) -> Self { + if let Ok(py_str) = py_any.downcast::() { py_str.to_string_lossy().as_ref().into() - } else if let Some(key_int) = extract_i64(self) { + } else if let Some(key_int) = extract_i64(py_any) { key_int.into() } else { - safe_repr(self).to_string().into() + safe_repr(py_any).to_string().into() } } } -impl AsLocItem for &'_ PyAny { - fn as_loc_item(&self) -> LocItem { - AsLocItem::as_loc_item(*self) +impl From for LocItem { + fn from(py_any: PyAny) -> Self { + (&py_any).into() } } @@ -244,22 +244,6 @@ impl<'a> Input<'a> for PyAny { Err(ValError::new(ErrorTypeDefaults::StringType, self)) } - fn exact_str(&'a self) -> ValResult> { - if let Ok(py_str) = ::try_from_exact(self) { - Ok(EitherString::Py(py_str)) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } - } - - fn exact_int(&'a self) -> ValResult> { - if PyInt::is_exact_type_of(self) { - Ok(EitherInt::Py(self)) - } else { - Err(ValError::new(ErrorTypeDefaults::IntType, self)) - } - } - fn validate_bytes(&'a self, strict: bool) -> ValResult>> { if let Ok(py_bytes) = self.downcast_exact::() { return Ok(ValidationMatch::exact(py_bytes.into())); @@ -347,6 +331,22 @@ impl<'a> Input<'a> for PyAny { Err(ValError::new(ErrorTypeDefaults::IntType, self)) } + fn exact_int(&'a self) -> ValResult> { + if PyInt::is_exact_type_of(self) { + Ok(EitherInt::Py(self)) + } else { + Err(ValError::new(ErrorTypeDefaults::IntType, self)) + } + } + + fn 
exact_str(&'a self) -> ValResult> { + if let Ok(py_str) = ::try_from_exact(self) { + Ok(EitherString::Py(py_str)) + } else { + Err(ValError::new(ErrorTypeDefaults::IntType, self)) + } + } + fn validate_float(&'a self, strict: bool) -> ValResult>> { if let Ok(float) = self.downcast_exact::() { return Ok(ValidationMatch::exact(EitherFloat::Py(float))); diff --git a/src/input/input_string.rs b/src/input/input_string.rs index 0c2c9a8ca..cd5931b24 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -3,7 +3,7 @@ use pyo3::types::{PyDict, PyString}; use speedate::MicrosecondsPrecisionOverflowBehavior; -use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; +use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult}; use crate::input::py_string_str; use crate::tools::safe_repr; use crate::validators::decimal::create_decimal; @@ -17,7 +17,7 @@ use super::{ GenericIterator, GenericMapping, Input, ValidationMatch, }; -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum StringMapping<'py> { String(&'py PyString), Mapping(&'py PyDict), @@ -52,11 +52,11 @@ impl<'py> StringMapping<'py> { } } -impl AsLocItem for StringMapping<'_> { - fn as_loc_item(&self) -> LocItem { - match self { - Self::String(s) => s.to_string_lossy().as_ref().into(), - Self::Mapping(d) => safe_repr(d).to_string().into(), +impl From> for LocItem { + fn from(string_mapping: StringMapping<'_>) -> Self { + match string_mapping { + StringMapping::String(s) => s.to_string_lossy().as_ref().into(), + StringMapping::Mapping(d) => safe_repr(d).to_string().into(), } } } diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 905b895f9..b323383b7 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -205,7 +205,7 @@ fn validate_iter_to_vec<'a, 's>( } Err(ValError::LineErrors(line_errors)) => { max_length_check.incr()?; - errors.extend(line_errors.into_iter().map(|err| 
err.with_outer_location(index.into()))); + errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index))); } Err(ValError::Omit) => (), Err(err) => return Err(err), @@ -284,7 +284,7 @@ fn validate_iter_to_set<'a, 's>( } } Err(ValError::LineErrors(line_errors)) => { - errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); + errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index))); } Err(ValError::Omit) => (), Err(err) => return Err(err), diff --git a/src/lookup_key.rs b/src/lookup_key.rs index e4d2dcce7..d4fba403f 100644 --- a/src/lookup_key.rs +++ b/src/lookup_key.rs @@ -366,11 +366,11 @@ impl LookupPath { pub fn apply_error_loc(&self, mut line_error: ValLineError, loc_by_alias: bool, field_name: &str) -> ValLineError { if loc_by_alias { for path_item in self.iter().rev() { - line_error = line_error.with_outer_location(path_item.clone().into()); + line_error = line_error.with_outer_location(path_item.clone()); } line_error } else { - line_error.with_outer_location(field_name.to_string().into()) + line_error.with_outer_location(field_name) } } diff --git a/src/validators/arguments.rs b/src/validators/arguments.rs index a605ff999..661aaf1e4 100644 --- a/src/validators/arguments.rs +++ b/src/validators/arguments.rs @@ -6,7 +6,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{schema_or_config_same, ExtraBehavior}; -use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input, ValidationMatch}; use crate::lookup_key::LookupKey; @@ -209,7 +209,7 @@ impl Validator for ArgumentsValidator { { Ok(value) => output_args.push(value), Err(ValError::LineErrors(line_errors)) => { - errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); + errors.extend(line_errors.into_iter().map(|err| 
err.with_outer_location(index))); } Err(err) => return Err(err), } @@ -263,7 +263,7 @@ impl Validator for ArgumentsValidator { errors.extend( line_errors .into_iter() - .map(|err| err.with_outer_location((index + self.positional_params_count).into())), + .map(|err| err.with_outer_location(index + self.positional_params_count)), ); } Err(err) => return Err(err), @@ -289,7 +289,7 @@ impl Validator for ArgumentsValidator { Err(ValError::LineErrors(line_errors)) => { for err in line_errors { errors.push( - err.with_outer_location(raw_key.as_loc_item()) + err.with_outer_location(raw_key) .with_type(ErrorTypeDefaults::InvalidKey), ); } @@ -303,7 +303,7 @@ impl Validator for ArgumentsValidator { Ok(value) => output_kwargs.set_item(either_str.as_py_string(py), value)?, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(raw_key.as_loc_item())); + errors.push(err.with_outer_location(raw_key)); } } Err(err) => return Err(err), @@ -313,7 +313,7 @@ impl Validator for ArgumentsValidator { errors.push(ValLineError::new_with_loc( ErrorTypeDefaults::UnexpectedKeywordArgument, value, - raw_key.as_loc_item(), + raw_key, )); } } diff --git a/src/validators/call.rs b/src/validators/call.rs index bf80415af..32b76f582 100644 --- a/src/validators/call.rs +++ b/src/validators/call.rs @@ -93,7 +93,7 @@ impl Validator for CallValidator { if let Some(return_validator) = &self.return_validator { return_validator .validate(py, return_value.into_ref(py), state) - .map_err(|e| e.with_outer_location("return".into())) + .map_err(|e| e.with_outer_location("return")) } else { Ok(return_value.to_object(py)) } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 3161541f0..bfeac2115 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -7,7 +7,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use 
crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::InputType; use crate::input::{BorrowInput, GenericArguments, Input, ValidationMatch}; use crate::lookup_key::LookupKey; @@ -231,7 +231,7 @@ impl Validator for DataclassArgsValidator { errors.extend( line_errors .into_iter() - .map(|err| err.with_outer_location(index.into())), + .map(|err| err.with_outer_location(index)), ); } Err(err) => return Err(err), @@ -310,7 +310,7 @@ impl Validator for DataclassArgsValidator { ValLineError::new_with_loc( ErrorTypeDefaults::UnexpectedKeywordArgument, value, - raw_key.as_loc_item(), + raw_key, ), ); } @@ -322,9 +322,7 @@ impl Validator for DataclassArgsValidator { .set_item(either_str.as_py_string(py), value)?, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location( - raw_key.as_loc_item(), - )); + errors.push(err.with_outer_location(raw_key)); } } Err(err) => return Err(err), @@ -339,7 +337,7 @@ impl Validator for DataclassArgsValidator { Err(ValError::LineErrors(line_errors)) => { for err in line_errors { errors.push( - err.with_outer_location(raw_key.as_loc_item()) + err.with_outer_location(raw_key) .with_type(ErrorTypeDefaults::InvalidKey), ); } @@ -430,7 +428,7 @@ impl Validator for DataclassArgsValidator { Err(ValError::LineErrors(line_errors)) => { let errors = line_errors .into_iter() - .map(|e| e.with_outer_location(field_name.into())) + .map(|e| e.with_outer_location(field_name)) .collect(); Err(ValError::LineErrors(errors)) } diff --git a/src/validators/dict.rs b/src/validators/dict.rs index d6507e8e9..3f8ff820b 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -3,7 +3,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use crate::build_tools::is_strict; -use crate::errors::{AsLocItem, ValError, ValLineError, ValResult}; +use crate::errors::{LocItem, 
ValError, ValLineError, ValResult}; use crate::input::BorrowInput; use crate::input::{ DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, @@ -103,7 +103,12 @@ impl DictValidator { &'s self, py: Python<'data>, input: &'data impl Input<'data>, - mapping_iter: impl Iterator>, + mapping_iter: impl Iterator< + Item = ValResult<( + impl BorrowInput + Clone + Into + 'data, + impl BorrowInput + 'data, + )>, + >, state: &mut ValidationState, ) -> ValResult { let output = PyDict::new(py); @@ -118,10 +123,7 @@ impl DictValidator { Err(ValError::LineErrors(line_errors)) => { for err in line_errors { // these are added in reverse order so [key] is shunted along by the second call - errors.push( - err.with_outer_location("[key]".into()) - .with_outer_location(key.as_loc_item()), - ); + errors.push(err.with_outer_location("[key]").with_outer_location(key.clone())); } None } @@ -132,7 +134,7 @@ impl DictValidator { Ok(value) => Some(value), Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(key.as_loc_item())); + errors.push(err.with_outer_location(key.clone())); } None } diff --git a/src/validators/function.rs b/src/validators/function.rs index e7134ab29..34e0b6327 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -6,7 +6,7 @@ use pyo3::types::{PyAny, PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; use crate::errors::{ - AsLocItem, ErrorType, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError, + ErrorType, PydanticCustomError, PydanticKnownError, PydanticOmit, ValError, ValResult, ValidationError, }; use crate::input::Input; use crate::py_gc::PyGcTraverse; @@ -371,7 +371,7 @@ struct ValidatorCallable { #[pymethods] impl ValidatorCallable { fn __call__(&mut self, py: Python, input_value: &PyAny, outer_location: Option<&PyAny>) -> PyResult { - let outer_location = 
outer_location.map(AsLocItem::as_loc_item); + let outer_location = outer_location.map(Into::into); self.validator.validate(py, input_value, outer_location) } @@ -399,7 +399,7 @@ struct AssignmentValidatorCallable { #[pymethods] impl AssignmentValidatorCallable { fn __call__(&mut self, py: Python, input_value: &PyAny, outer_location: Option<&PyAny>) -> PyResult { - let outer_location = outer_location.map(AsLocItem::as_loc_item); + let outer_location = outer_location.map(Into::into); self.validator.validate_assignment( py, input_value, diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index c73a2821d..192d1ba8f 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -7,7 +7,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config_same, ExtraBehavior}; -use crate::errors::{AsLocItem, ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, StringMappingGenericIterator, ValidationMatch, @@ -183,7 +183,7 @@ impl Validator for ModelFieldsValidator { Ok(v) => v, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(field.name.as_loc_item())); + errors.push(err.with_outer_location(&field.name)); } continue; } @@ -256,7 +256,7 @@ impl Validator for ModelFieldsValidator { Err(ValError::LineErrors(line_errors)) => { for err in line_errors { errors.push( - err.with_outer_location(raw_key.as_loc_item()) + err.with_outer_location(raw_key.clone()) .with_type(ErrorTypeDefaults::InvalidKey) ); @@ -278,7 +278,7 @@ impl Validator for ModelFieldsValidator { ValLineError::new_with_loc( ErrorTypeDefaults::ExtraForbidden, value, - raw_key.as_loc_item(), + raw_key, ) ); @@ 
-294,7 +294,7 @@ impl Validator for ModelFieldsValidator { } Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(raw_key.as_loc_item())); + errors.push(err.with_outer_location(raw_key.clone())); } } Err(err) => return Err(err), @@ -355,7 +355,7 @@ impl Validator for ModelFieldsValidator { Err(ValError::LineErrors(line_errors)) => { let errors = line_errors .into_iter() - .map(|e| e.with_outer_location(field_name.to_string().into())) + .map(|e| e.with_outer_location(field_name)) .collect(); Err(ValError::LineErrors(errors)) } diff --git a/src/validators/tuple.rs b/src/validators/tuple.rs index db9c57953..6cad26073 100644 --- a/src/validators/tuple.rs +++ b/src/validators/tuple.rs @@ -76,7 +76,7 @@ impl TupleValidator { Some((index, input_item)) => match validator.validate(py, input_item, state) { Ok(item) => self.push_output_item(input, output, item, actual_length)?, Err(ValError::LineErrors(line_errors)) => { - errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); + errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index))); } Err(ValError::Omit) => (), Err(err) => return Err(err), @@ -136,7 +136,7 @@ impl TupleValidator { match variable_validator.validate(py, input_item, state) { Ok(item) => self.push_output_item(input, &mut output, item, actual_length)?, Err(ValError::LineErrors(line_errors)) => { - errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index.into()))); + errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index))); } Err(ValError::Omit) => (), Err(err) => return Err(err), @@ -167,7 +167,7 @@ impl TupleValidator { errors.extend( line_errors .into_iter() - .map(|err| err.with_outer_location(buffer_item_index.into())), + .map(|err| err.with_outer_location(buffer_item_index)), ); } Err(ValError::Omit) => (), diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 
d2fc00ec7..ec0cc07cf 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -8,7 +8,7 @@ use ahash::AHashSet; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config, schema_or_config_same, ExtraBehavior}; -use crate::errors::{AsLocItem, ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; use crate::input::{ AttributesGenericIterator, BorrowInput, DictGenericIterator, GenericMapping, Input, JsonObjectGenericIterator, MappingGenericIterator, StringMappingGenericIterator, ValidationMatch, @@ -183,7 +183,7 @@ impl Validator for TypedDictValidator { Ok(v) => v, Err(ValError::LineErrors(line_errors)) => { for err in line_errors { - errors.push(err.with_outer_location(field.name.as_loc_item())); + errors.push(err.with_outer_location(&field.name)); } continue; } @@ -256,7 +256,7 @@ impl Validator for TypedDictValidator { Err(ValError::LineErrors(line_errors)) => { for err in line_errors { errors.push( - err.with_outer_location(raw_key.as_loc_item()) + err.with_outer_location(raw_key.clone()) .with_type(ErrorTypeDefaults::InvalidKey) ); @@ -278,7 +278,7 @@ impl Validator for TypedDictValidator { ValLineError::new_with_loc( ErrorTypeDefaults::ExtraForbidden, value, - raw_key.as_loc_item(), + raw_key, ) ); @@ -294,9 +294,7 @@ impl Validator for TypedDictValidator { Err(ValError::LineErrors(line_errors)) => { for err in line_errors { errors.push( - err - .with_outer_location(raw_key.as_loc_item()) - + err.with_outer_location(raw_key.clone()) ); } } diff --git a/src/validators/union.rs b/src/validators/union.rs index e010348af..0a5bd0c65 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -8,7 +8,7 @@ use smallvec::SmallVec; use crate::build_tools::py_schema_err; use crate::build_tools::{is_strict, schema_or_config}; -use crate::errors::{AsLocItem, ErrorType, ValError, ValLineError, ValResult}; +use 
crate::errors::{ErrorType, ValError, ValLineError, ValResult}; use crate::input::{GenericMapping, Input}; use crate::lookup_key::LookupKey; use crate::py_gc::PyGcTraverse; @@ -263,7 +263,7 @@ impl<'a> MaybeErrors<'a> { }| { line_errors.into_iter().map(move |err| { let case_label = label.unwrap_or(choice.get_name()); - err.with_outer_location(case_label.into()) + err.with_outer_location(case_label) }) }, ) @@ -486,7 +486,7 @@ impl TaggedUnionValidator { if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) { return match validator.validate(py, input, state) { Ok(res) => Ok(res), - Err(err) => Err(err.with_outer_location(tag.as_loc_item())), + Err(err) => Err(err.with_outer_location(tag)), }; } match self.custom_error { diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index a17c2cefc..f520eaf38 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -170,7 +170,7 @@ impl Validator for WithDefaultValidator { Ok(v) => Ok(Some(v)), Err(e) => { if let Some(outer_loc) = outer_loc { - Err(e.with_outer_location(outer_loc.into())) + Err(e.with_outer_location(outer_loc)) } else { Err(e) } From 8be45e67aa0de48727331057366e2abc21cfc2c5 Mon Sep 17 00:00:00 2001 From: Smixi Date: Mon, 22 Jan 2024 14:09:25 +0100 Subject: [PATCH 193/550] fix: 8405 pattern serialization (#1168) --- src/serializers/infer.rs | 12 ++++++++++++ src/serializers/ob_type.rs | 10 ++++++++++ tests/serializers/test_any.py | 3 ++- 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 487d0d091..edbe614e4 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -219,6 +219,7 @@ pub(crate) fn infer_to_python_known( PyList::new(py, items).into_py(py) } ObType::Path => value.str()?.into_py(py), + ObType::Pattern => value.getattr(intern!(py, "pattern"))?.into_py(py), ObType::Unknown => { if let Some(fallback) = extra.fallback { let next_value = 
fallback.call1((value,))?; @@ -505,6 +506,16 @@ pub(crate) fn infer_serialize_known( let s = value.str().map_err(py_err_se_err)?.to_str().map_err(py_err_se_err)?; serializer.serialize_str(s) } + ObType::Pattern => { + let s = value + .getattr(intern!(value.py(), "pattern")) + .map_err(py_err_se_err)? + .str() + .map_err(py_err_se_err)? + .to_str() + .map_err(py_err_se_err)?; + serializer.serialize_str(s) + } ObType::Unknown => { if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,)).map_err(py_err_se_err)?; @@ -628,6 +639,7 @@ pub(crate) fn infer_json_key_known<'py>(ob_type: ObType, key: &'py PyAny, extra: infer_json_key(k, extra) } ObType::Path => Ok(key.str()?.to_string_lossy()), + ObType::Pattern => Ok(key.getattr(intern!(key.py(), "pattern"))?.str()?.to_string_lossy()), ObType::Unknown => { if let Some(fallback) = extra.fallback { let next_key = fallback.call1((key,))?; diff --git a/src/serializers/ob_type.rs b/src/serializers/ob_type.rs index 7493d5abd..924de001b 100644 --- a/src/serializers/ob_type.rs +++ b/src/serializers/ob_type.rs @@ -44,6 +44,8 @@ pub struct ObTypeLookup { generator_object: PyObject, // path path_object: PyObject, + // pattern + pattern_object: PyObject, // uuid type uuid_object: PyObject, } @@ -87,6 +89,7 @@ impl ObTypeLookup { .unwrap() .to_object(py), path_object: py.import("pathlib").unwrap().getattr("Path").unwrap().to_object(py), + pattern_object: py.import("re").unwrap().getattr("Pattern").unwrap().to_object(py), uuid_object: py.import("uuid").unwrap().getattr("UUID").unwrap().to_object(py), } } @@ -150,6 +153,7 @@ impl ObTypeLookup { ObType::Enum => self.enum_object.as_ptr() as usize == ob_type, ObType::Generator => self.generator_object.as_ptr() as usize == ob_type, ObType::Path => self.path_object.as_ptr() as usize == ob_type, + ObType::Pattern => self.path_object.as_ptr() as usize == ob_type, ObType::Uuid => self.uuid_object.as_ptr() as usize == ob_type, ObType::Unknown => false, }; @@ -242,6 +246,8 
@@ impl ObTypeLookup { ObType::Generator } else if ob_type == self.path_object.as_ptr() as usize { ObType::Path + } else if ob_type == self.pattern_object.as_ptr() as usize { + ObType::Pattern } else { // this allows for subtypes of the supported class types, // if `ob_type` didn't match any member of self, we try again with the next base type pointer @@ -319,6 +325,8 @@ impl ObTypeLookup { ObType::Generator } else if value.is_instance(self.path_object.as_ref(py)).unwrap_or(false) { ObType::Path + } else if value.is_instance(self.pattern_object.as_ref(py)).unwrap_or(false) { + ObType::Pattern } else { ObType::Unknown } @@ -396,6 +404,8 @@ pub enum ObType { Generator, // Path Path, + //Pattern, + Pattern, // Uuid Uuid, // unknown type diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index fa6e702fe..875a3fcf1 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -1,6 +1,7 @@ import dataclasses import json import platform +import re import sys from collections import namedtuple from datetime import date, datetime, time, timedelta, timezone @@ -437,7 +438,7 @@ def test_base64(): (lambda: MyEnum.a, {}, b'1'), (lambda: MyEnum.b, {}, b'"b"'), (lambda: [MyDataclass(1, 'a', 2), MyModel(a=2, b='b')], {}, b'[{"a":1,"b":"a"},{"a":2,"b":"b"}]'), - # # (lambda: re.compile('^regex$'), b'"^regex$"'), + (lambda: re.compile('^regex$'), {}, b'"^regex$"'), ], ) def test_encoding(any_serializer, gen_input, kwargs, expected_json): From 758bc51d4bb7e94aae0c8e6f4ce515b2d55cc3a6 Mon Sep 17 00:00:00 2001 From: Jean Arhancet Date: Thu, 25 Jan 2024 16:29:12 +0100 Subject: [PATCH 194/550] fix(uuid): validation from string (#1172) --- src/validators/uuid.rs | 15 +++++++++++++++ tests/validators/test_uuid.py | 5 +++++ 2 files changed, 20 insertions(+) diff --git a/src/validators/uuid.rs b/src/validators/uuid.rs index 4cfd7a272..c0854c10f 100644 --- a/src/validators/uuid.rs +++ b/src/validators/uuid.rs @@ -5,6 +5,7 @@ use pyo3::prelude::*; 
use pyo3::sync::GILOnceCell; use pyo3::types::{PyDict, PyType}; use uuid::Uuid; +use uuid::Variant; use crate::build_tools::is_strict; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; @@ -125,6 +126,20 @@ impl Validator for UuidValidator { state.floor_exactness(Exactness::Lax); } let uuid = self.get_uuid(input)?; + // This block checks if the UUID version matches the expected version and + // if the UUID variant conforms to RFC 4122. When dealing with Python inputs, + // UUIDs must adhere to RFC 4122 standards. + if let Some(expected_version) = self.version { + if uuid.get_version_num() != expected_version || uuid.get_variant() != Variant::RFC4122 { + return Err(ValError::new( + ErrorType::UuidVersion { + expected_version, + context: None, + }, + input, + )); + } + } self.create_py_uuid(py, class, &uuid) } } diff --git a/tests/validators/test_uuid.py b/tests/validators/test_uuid.py index 9b4ef60cf..eae950a63 100644 --- a/tests/validators/test_uuid.py +++ b/tests/validators/test_uuid.py @@ -24,6 +24,7 @@ ('886313e1-3b8a-5372-9b90-0c9aee199e5d', UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d')), ('c0a8f9a8-aa5e-482b-a067-9cb3a51f5c11', UUID('c0a8f9a8-aa5e-482b-a067-9cb3a51f5c11')), ('00000000-8000-4000-8000-000000000000', UUID('00000000-8000-4000-8000-000000000000')), + ('00000000-0000-4000-0000-000000000000', UUID('00000000-0000-4000-0000-000000000000')), (b'\x12\x34\x56\x78' * 4, UUID('12345678-1234-5678-1234-567812345678')), (b'\x00\x00\x00\x00' * 4, UUID('00000000-0000-0000-0000-000000000000')), (b'ebcdab58-6eb8-46fb-a190-d07a33e9eac8', UUID('ebcdab58-6eb8-46fb-a190-d07a33e9eac8')), @@ -123,6 +124,8 @@ def test_uuid_strict(input_value, expected): # `UUID.version` makes sense for RFC 4122 UUIDs only. 
For non RFC 4122 UUIDs Python uses `UUID.version=None` ('00000000-8000-4000-8000-000000000000', 4, UUID('00000000-8000-4000-8000-000000000000')), (UUID('00000000-8000-4000-8000-000000000000'), 4, UUID('00000000-8000-4000-8000-000000000000')), + ('00000000-0000-4000-0000-000000000000', None, UUID('00000000-0000-4000-0000-000000000000')), + (UUID('00000000-0000-4000-0000-000000000000'), None, UUID('00000000-0000-4000-0000-000000000000')), ('00000000-7fff-4000-7fff-000000000000', None, UUID('00000000-7fff-4000-7fff-000000000000')), (UUID('00000000-7fff-4000-7fff-000000000000'), None, UUID('00000000-7fff-4000-7fff-000000000000')), (UUID('00000000-7fff-4000-7fff-000000000000'), 4, Err('UUID version 4 expected')), @@ -138,6 +141,8 @@ def test_uuid_strict(input_value, expected): (UUID('0e7ac198-9acd-4c0c-b4b4-761974bf71d7'), 3, Err('UUID version 3 expected')), ('08ed0736-fb95-5cc5-85ed-37e4f3df9b29', 1, Err('UUID version 1 expected')), (UUID('08ed0736-fb95-5cc5-85ed-37e4f3df9b29'), 1, Err('UUID version 1 expected')), + ('00000000-0000-4000-0000-000000000000', 4, Err('UUID version 4 expected')), + (UUID('00000000-0000-4000-0000-000000000000'), 4, Err('UUID version 4 expected')), ], ) def test_uuid_version(input_value, version, expected): From dcaf63ece0836d8f09707d66f9d66721bfa50dd9 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 1 Feb 2024 14:04:09 +0000 Subject: [PATCH 195/550] Fix warning for tuple of wrong size in union (#1174) --- src/serializers/type_serializers/tuple.rs | 244 +++++++++++----------- tests/serializers/test_list_tuple.py | 26 +++ 2 files changed, 144 insertions(+), 126 deletions(-) diff --git a/src/serializers/type_serializers/tuple.rs b/src/serializers/type_serializers/tuple.rs index e5f225c92..5890288d6 100644 --- a/src/serializers/type_serializers/tuple.rs +++ b/src/serializers/type_serializers/tuple.rs @@ -7,8 +7,10 @@ use std::iter; use serde::ser::SerializeSeq; use crate::definitions::DefinitionsBuilder; +use 
crate::serializers::extra::SerCheck; use crate::serializers::type_serializers::any::AnySerializer; use crate::tools::SchemaDict; +use crate::PydanticSerializationUnexpectedValue; use super::{ infer_json_key, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, CombinedSerializer, Extra, @@ -70,52 +72,14 @@ impl TypeSerializer for TupleSerializer { let py = value.py(); let n_items = py_tuple.len(); - let mut py_tuple_iter = py_tuple.iter(); let mut items = Vec::with_capacity(n_items); - macro_rules! use_serializers { - ($serializers_iter:expr) => { - for (index, serializer) in $serializers_iter.enumerate() { - let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, - }; - let op_next = self - .filter - .index_filter(index, include, exclude, Some(n_items))?; - if let Some((next_include, next_exclude)) = op_next { - items.push(serializer.to_python(element, next_include, next_exclude, extra)?); - } - } - }; - } - - if let Some(variadic_item_index) = self.variadic_item_index { - // Need `saturating_sub` to handle items with too few elements without panicking - let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); - let serializers_iter = self.serializers[..variadic_item_index] - .iter() - .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) - .chain(self.serializers[variadic_item_index + 1..].iter()); - use_serializers!(serializers_iter); - } else { - use_serializers!(self.serializers.iter()); - let mut warned = false; - for (i, element) in py_tuple_iter.enumerate() { - if !warned { - extra - .warnings - .custom_warning("Unexpected extra items present in tuple".to_string()); - warned = true; - } - let op_next = - self.filter - .index_filter(i + self.serializers.len(), include, exclude, Some(n_items))?; - if let Some((next_include, next_exclude)) = op_next { - items.push(AnySerializer.to_python(element, next_include, next_exclude, extra)?); - } - } - }; + 
self.for_each_tuple_item_and_serializer(py_tuple, include, exclude, extra, |entry| { + entry + .serializer + .to_python(entry.item, entry.include, entry.exclude, extra) + .map(|item| items.push(item)) + })??; match extra.mode { SerMode::Json => Ok(PyList::new(py, items).into_py(py)), @@ -132,35 +96,14 @@ impl TypeSerializer for TupleSerializer { fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { match key.downcast::() { Ok(py_tuple) => { - let mut py_tuple_iter = py_tuple.iter(); - let mut key_builder = KeyBuilder::new(); - let n_items = py_tuple.len(); - - macro_rules! use_serializers { - ($serializers_iter:expr) => { - for serializer in $serializers_iter { - let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, - }; - key_builder.push(&serializer.json_key(element, extra)?); - } - }; - } - - if let Some(variadic_item_index) = self.variadic_item_index { - // Need `saturating_sub` to handle items with too few elements without panicking - let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); - let serializers_iter = self.serializers[..variadic_item_index] - .iter() - .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) - .chain(self.serializers[variadic_item_index + 1..].iter()); - use_serializers!(serializers_iter); - } else { - use_serializers!(self.serializers.iter()); - }; + self.for_each_tuple_item_and_serializer(py_tuple, None, None, extra, |entry| { + entry + .serializer + .json_key(entry.item, extra) + .map(|key| key_builder.push(&key)) + })??; Ok(Cow::Owned(key_builder.finish())) } @@ -184,63 +127,18 @@ impl TypeSerializer for TupleSerializer { let py_tuple: &PyTuple = py_tuple.downcast().map_err(py_err_se_err)?; let n_items = py_tuple.len(); - let mut py_tuple_iter = py_tuple.iter(); let mut seq = serializer.serialize_seq(Some(n_items))?; - macro_rules! 
use_serializers { - ($serializers_iter:expr) => { - for (index, serializer) in $serializers_iter.enumerate() { - let element = match py_tuple_iter.next() { - Some(value) => value, - None => break, - }; - let op_next = self - .filter - .index_filter(index, include, exclude, Some(n_items)) - .map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let item_serialize = - PydanticSerializer::new(element, serializer, next_include, next_exclude, extra); - seq.serialize_element(&item_serialize)?; - } - } - }; - } - - if let Some(variadic_item_index) = self.variadic_item_index { - // Need `saturating_sub` to handle items with too few elements without panicking - let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); - let serializers_iter = self.serializers[..variadic_item_index] - .iter() - .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) - .chain(self.serializers[variadic_item_index + 1..].iter()); - use_serializers!(serializers_iter); - } else { - use_serializers!(self.serializers.iter()); - let mut warned = false; - for (i, element) in py_tuple_iter.enumerate() { - if !warned { - extra - .warnings - .custom_warning("Unexpected extra items present in tuple".to_string()); - warned = true; - } - let op_next = self - .filter - .index_filter(i + self.serializers.len(), include, exclude, Some(n_items)) - .map_err(py_err_se_err)?; - if let Some((next_include, next_exclude)) = op_next { - let item_serialize = PydanticSerializer::new( - element, - &CombinedSerializer::Any(AnySerializer), - next_include, - next_exclude, - extra, - ); - seq.serialize_element(&item_serialize)?; - } - } - }; + self.for_each_tuple_item_and_serializer(py_tuple, include, exclude, extra, |entry| { + seq.serialize_element(&PydanticSerializer::new( + entry.item, + entry.serializer, + entry.include, + entry.exclude, + extra, + )) + }) + .map_err(py_err_se_err)??; seq.end() } @@ -254,6 +152,100 @@ impl TypeSerializer for 
TupleSerializer { fn get_name(&self) -> &str { &self.name } + + fn retry_with_lax_check(&self) -> bool { + true + } +} + +struct TupleSerializerEntry<'a, 'py> { + item: &'py PyAny, + include: Option<&'py PyAny>, + exclude: Option<&'py PyAny>, + serializer: &'a CombinedSerializer, +} + +impl TupleSerializer { + /// Try to serialize each item in the tuple with the corresponding serializer. + /// + /// If the tuple doesn't match the length of the serializer, in strict mode, an error is returned. + /// + /// The error type E is the type of the error returned by the closure, which is why there are two + /// levels of `Result`. + fn for_each_tuple_item_and_serializer( + &self, + tuple: &PyTuple, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + mut f: impl for<'a, 'py> FnMut(TupleSerializerEntry<'a, 'py>) -> Result<(), E>, + ) -> PyResult> { + let n_items = tuple.len(); + let mut py_tuple_iter = tuple.iter(); + + macro_rules! use_serializers { + ($serializers_iter:expr) => { + for (index, serializer) in $serializers_iter.enumerate() { + let element = match py_tuple_iter.next() { + Some(value) => value, + None => break, + }; + let op_next = self.filter.index_filter(index, include, exclude, Some(n_items))?; + if let Some((next_include, next_exclude)) = op_next { + if let Err(e) = f(TupleSerializerEntry { + item: element, + include: next_include, + exclude: next_exclude, + serializer, + }) { + return Ok(Err(e)); + }; + } + } + }; + } + + if let Some(variadic_item_index) = self.variadic_item_index { + // Need `saturating_sub` to handle items with too few elements without panicking + let n_variadic_items = (n_items + 1).saturating_sub(self.serializers.len()); + let serializers_iter = self.serializers[..variadic_item_index] + .iter() + .chain(iter::repeat(&self.serializers[variadic_item_index]).take(n_variadic_items)) + .chain(self.serializers[variadic_item_index + 1..].iter()); + use_serializers!(serializers_iter); + } else if extra.check == 
SerCheck::Strict && n_items != self.serializers.len() { + return Err(PydanticSerializationUnexpectedValue::new_err(Some(format!( + "Expected {} items, but got {}", + self.serializers.len(), + n_items + )))); + } else { + use_serializers!(self.serializers.iter()); + let mut warned = false; + for (i, element) in py_tuple_iter.enumerate() { + if !warned { + extra + .warnings + .custom_warning("Unexpected extra items present in tuple".to_string()); + warned = true; + } + let op_next = self + .filter + .index_filter(i + self.serializers.len(), include, exclude, Some(n_items))?; + if let Some((next_include, next_exclude)) = op_next { + if let Err(e) = f(TupleSerializerEntry { + item: element, + include: next_include, + exclude: next_exclude, + serializer: &CombinedSerializer::Any(AnySerializer), + }) { + return Ok(Err(e)); + }; + } + } + }; + Ok(Ok(())) + } } pub(crate) struct KeyBuilder { diff --git a/tests/serializers/test_list_tuple.py b/tests/serializers/test_list_tuple.py index df6aabd8a..256a342fe 100644 --- a/tests/serializers/test_list_tuple.py +++ b/tests/serializers/test_list_tuple.py @@ -411,3 +411,29 @@ def test_tuple_pos_dict_key(): assert s.to_python({(1, 'a', 2): 1}, mode='json') == {'1,a,2': 1} assert s.to_json({(1, 'a'): 1}) == b'{"1,a":1}' assert s.to_json({(1, 'a', 2): 1}) == b'{"1,a,2":1}' + + +def test_tuple_wrong_size_union(): + # See https://github.com/pydantic/pydantic/issues/8677 + + f = core_schema.float_schema() + s = SchemaSerializer( + core_schema.union_schema([core_schema.tuple_schema([f, f]), core_schema.tuple_schema([f, f, f])]) + ) + assert s.to_python((1.0, 2.0)) == (1.0, 2.0) + assert s.to_python((1.0, 2.0, 3.0)) == (1.0, 2.0, 3.0) + + with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'): + s.to_python((1.0, 2.0, 3.0, 4.0)) + + assert s.to_python((1.0, 2.0), mode='json') == [1.0, 2.0] + assert s.to_python((1.0, 2.0, 3.0), mode='json') == [1.0, 2.0, 3.0] + + with pytest.warns(UserWarning, match='Unexpected 
extra items present in tuple'): + s.to_python((1.0, 2.0, 3.0, 4.0), mode='json') + + assert s.to_json((1.0, 2.0)) == b'[1.0,2.0]' + assert s.to_json((1.0, 2.0, 3.0)) == b'[1.0,2.0,3.0]' + + with pytest.warns(UserWarning, match='Unexpected extra items present in tuple'): + s.to_json((1.0, 2.0, 3.0, 4.0)) From 42394acb0fdd9b7cb64ec56174be4ac3b9eca5c9 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 1 Feb 2024 12:57:54 -0600 Subject: [PATCH 196/550] Fix model computed field serializer (json) (#1187) --- src/serializers/computed_fields.rs | 2 +- tests/serializers/test_model.py | 45 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/serializers/computed_fields.rs b/src/serializers/computed_fields.rs index 6f4553069..43e284fdd 100644 --- a/src/serializers/computed_fields.rs +++ b/src/serializers/computed_fields.rs @@ -75,7 +75,7 @@ impl ComputedFields { let property_name_py = computed_field.property_name_py.as_ref(model.py()); let value = model.getattr(property_name_py).map_err(py_err_se_err)?; if extra.exclude_none && value.is_none() { - return Ok(()); + continue; } if let Some((next_include, next_exclude)) = filter .key_filter(property_name_py, include, exclude) diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index 4249c3015..0be83e538 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -650,6 +650,51 @@ def volume(self) -> None: assert s.to_json(Model(3, 4), exclude_none=True) == b'{"width":3,"height":4,"Area":12}' +def test_computed_field_exclude_none_different_order(): + # verify that order of computed fields doesn't matter + # issue originally reported via: https://github.com/pydantic/pydantic/issues/8691 + + @dataclasses.dataclass + class Model: + width: int + height: int + + @property + def volume(self) -> None: + return None + + @property + def area(self) -> int: + return self.width * self.height 
+ + s = SchemaSerializer( + core_schema.model_schema( + Model, + core_schema.model_fields_schema( + { + 'width': core_schema.model_field(core_schema.int_schema()), + 'height': core_schema.model_field(core_schema.int_schema()), + }, + computed_fields=[ + core_schema.computed_field('volume', core_schema.int_schema()), + core_schema.computed_field('area', core_schema.int_schema(), alias='Area'), + ], + ), + ) + ) + assert s.to_python(Model(3, 4), exclude_none=False) == {'width': 3, 'height': 4, 'Area': 12, 'volume': None} + assert s.to_python(Model(3, 4), exclude_none=True) == {'width': 3, 'height': 4, 'Area': 12} + assert s.to_python(Model(3, 4), mode='json', exclude_none=False) == { + 'width': 3, + 'height': 4, + 'Area': 12, + 'volume': None, + } + assert s.to_python(Model(3, 4), mode='json', exclude_none=True) == {'width': 3, 'height': 4, 'Area': 12} + assert s.to_json(Model(3, 4), exclude_none=False) == b'{"width":3,"height":4,"volume":null,"Area":12}' + assert s.to_json(Model(3, 4), exclude_none=True) == b'{"width":3,"height":4,"Area":12}' + + @pytest.mark.skipif(cached_property is None, reason='cached_property is not available') def test_cached_property_alias(): @dataclasses.dataclass From 305a837e9f76872a5e91db396c8de22cc81aaffc Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Fri, 2 Feb 2024 11:59:10 -0600 Subject: [PATCH 197/550] Version bump, 2.16.2 (#1188) --- Cargo.lock | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5d113589f..40419c743 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "pydantic-core" -version = "2.16.1" +version = "2.16.2" dependencies = [ "ahash", "base64", diff --git a/Cargo.toml b/Cargo.toml index c26c91419..a10c8e0f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pydantic-core" -version = "2.16.1" +version = "2.16.2" edition = "2021" license 
= "MIT" homepage = "https://github.com/pydantic/pydantic-core" From 858d85df0fac21f4230e110b0020fea8825309c7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:07:37 +0000 Subject: [PATCH 198/550] Bump the python-packages group with 7 updates (#1182) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- tests/requirements-linting.txt | 6 +++--- tests/requirements.txt | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 043482646..1692fa86d 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ -griffe==0.38.1 -pyright==1.1.345 -ruff==0.1.13 +griffe==0.40.0 +pyright==1.1.349 +ruff==0.1.15 mypy==1.8.0 diff --git a/tests/requirements.txt b/tests/requirements.txt index 29bf5bcad..756792f2a 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,12 +1,12 @@ -coverage==7.4.0 +coverage==7.4.1 dirty-equals==0.7.1.post0 -hypothesis==6.92.9 +hypothesis==6.97.4 # TODO: remove manual override for dateutil once a version newer than 2.8.2 is # released which removes use of deprecated utcfromtimestamp git+https://github.com/dateutil/dateutil.git@f2293200747fb03d56c6c5997bfebeabe703576f # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux pandas==2.1.3; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64' -pytest==7.4.4 +pytest==8.0.0 # we run codspeed benchmarks on x86_64 CPython (i.e. 
native github actions architecture) pytest-codspeed~=2.2.0; implementation_name == "cpython" and platform_machine == 'x86_64' # pytest-examples currently depends on aiohttp via black; we don't want to build @@ -16,7 +16,7 @@ pytest-speed==0.3.5 pytest-mock==3.11.1 pytest-pretty==1.2.0 pytest-timeout==2.2.0 -pytz==2023.3.post1 +pytz==2023.4 # numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux numpy==1.26.2; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64' exceptiongroup==1.1; python_version < "3.11" From b393063f98c995726f775023927f6a68a5e1a2e2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:08:12 +0000 Subject: [PATCH 199/550] Bump actions/cache from 3 to 4 (#1186) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- .github/workflows/codspeed.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 826635499..42a210688 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -239,7 +239,7 @@ jobs: with: node-version: '18' - - uses: actions/cache@v3 + - uses: actions/cache@v4 id: cache-py name: cache python with: diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index fc6591314..33497021f 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -20,7 +20,7 @@ jobs: with: python-version: '3.12' - - uses: actions/cache@v3 + - uses: actions/cache@v4 id: cache-py name: cache python with: From 1fb3343cbfe4cb75de0f542dd002daa3a3f497af Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:08:20 +0000 Subject: [PATCH 200/550] Bump mymindstorm/setup-emsdk from 13 to 
14 (#1185) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 42a210688..884cd663c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -304,7 +304,7 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 - - uses: mymindstorm/setup-emsdk@v13 + - uses: mymindstorm/setup-emsdk@v14 with: # NOTE!: as per https://github.com/pydantic/pydantic-core/pull/149 this version needs to match the version # in node_modules/pyodide/repodata.json, to get the version, run: From ff073f0c9a70c2b4fe1bfcb1a98cf7a395a962c1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:08:45 +0000 Subject: [PATCH 201/550] Bump codecov/codecov-action from 3 to 4 (#1184) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 884cd663c..57dba20a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,7 @@ jobs: - run: coverage-prepare lcov python/pydantic_core/*.so - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 # See https://github.com/PyO3/pyo3/discussions/2781 # tests intermittently segfault with pypy and cpython 3.7 when using `coverage run ...`, hence separate job From 5ba2d52fbcf839f6655b3845bfcaef9a6a798474 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:09:18 +0000 Subject: [PATCH 202/550] Bump uraimo/run-on-arch-action from 2.6.0 to 2.7.1 (#1183) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] 
<49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 57dba20a2..6a4d32a4d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -611,7 +611,7 @@ jobs: merge-multiple: true path: dist - - uses: uraimo/run-on-arch-action@v2.6.0 + - uses: uraimo/run-on-arch-action@v2.7.1 name: install & test with: arch: ${{ matrix.target }} From d0a438a3847f8e194e81b33ce965b9969ce9b386 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:09:35 +0000 Subject: [PATCH 203/550] Bump smallvec from 1.11.2 to 1.13.1 (#1181) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 40419c743..28ceaf2ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -526,9 +526,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.2" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "speedate" diff --git a/Cargo.toml b/Cargo.toml index a10c8e0f4..ce5d70a60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ serde_json = {version = "1.0.109", features = ["arbitrary_precision", "preserve_ enum_dispatch = "0.3.8" serde = { version = "1.0.195", features = ["derive"] } speedate = "0.13.0" -smallvec = "1.11.2" +smallvec = "1.13.1" ahash = "0.8.7" url = "2.5.0" # idna is already required by url, added here to be explicit From a2bf4d8e3905a8a9d9c61958a55e265d8be160ca Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> 
Date: Mon, 5 Feb 2024 10:09:46 +0000 Subject: [PATCH 204/550] Bump regex from 1.10.2 to 1.10.3 (#1180) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 28ceaf2ff..29c8f9dec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -447,9 +447,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.2" +version = "1.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" dependencies = [ "aho-corasick", "memchr", @@ -459,9 +459,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ "aho-corasick", "memchr", diff --git a/Cargo.toml b/Cargo.toml index ce5d70a60..68983ae1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ [dependencies] pyo3 = { version = "0.20.2", features = ["generate-import-lib", "num-bigint"] } -regex = "1.10.2" +regex = "1.10.3" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" serde_json = {version = "1.0.109", features = ["arbitrary_precision", "preserve_order"]} From 0c804383967c6453a81b3c8e2ef0cf7004c5083f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:10:02 +0000 Subject: [PATCH 205/550] Bump uuid from 1.6.1 to 1.7.0 (#1179) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 
+- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 29c8f9dec..63637bd1e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -640,9 +640,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" [[package]] name = "version_check" diff --git a/Cargo.toml b/Cargo.toml index 68983ae1d..f74a9cafb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ idna = "0.5.0" base64 = "0.21.7" num-bigint = "0.4.4" python3-dll-a = "0.2.7" -uuid = "1.6.1" +uuid = "1.7.0" jiter = {version = "0.0.6", features = ["python"]} [lib] From 80660d582a173773f20a7342d735e920a6aff503 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:10:21 +0000 Subject: [PATCH 206/550] Bump serde_json from 1.0.109 to 1.0.113 (#1178) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 63637bd1e..d099e33ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -126,9 +126,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.0.0" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ "equivalent", "hashbrown", @@ -514,9 +514,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.109" +version = "1.0.113" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"cb0652c533506ad7a2e353cce269330d6afd8bdfb6d75e0ace5b35aacbd7b9e9" +checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" dependencies = [ "indexmap", "itoa", diff --git a/Cargo.toml b/Cargo.toml index f74a9cafb..d8b2ba182 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ pyo3 = { version = "0.20.2", features = ["generate-import-lib", "num-bigint"] } regex = "1.10.3" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" -serde_json = {version = "1.0.109", features = ["arbitrary_precision", "preserve_order"]} +serde_json = {version = "1.0.113", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.195", features = ["derive"] } speedate = "0.13.0" From 44621f1dee5160c737a0749d21e0b2118fa99705 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:19:19 +0000 Subject: [PATCH 207/550] Bump serde from 1.0.195 to 1.0.196 (#1177) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d099e33ef..c00c1131e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -494,18 +494,18 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.195" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.195" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +checksum = 
"33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index d8b2ba182..47b9d784c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" serde_json = {version = "1.0.113", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" -serde = { version = "1.0.195", features = ["derive"] } +serde = { version = "1.0.196", features = ["derive"] } speedate = "0.13.0" smallvec = "1.13.1" ahash = "0.8.7" From fadeecc0b25c9f9533fcd2e615f63a0e7e8be7a7 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 15 Feb 2024 13:33:34 +0000 Subject: [PATCH 208/550] ci: updates for Rust 1.76 (#1191) --- src/errors/value_exception.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/errors/value_exception.rs b/src/errors/value_exception.rs index 68f93d463..3152044d1 100644 --- a/src/errors/value_exception.rs +++ b/src/errors/value_exception.rs @@ -98,7 +98,7 @@ impl PydanticCustomError { fn __repr__(&self, py: Python) -> PyResult { let msg = self.message(py)?; - match { self.context.as_ref() } { + match self.context.as_ref() { Some(ctx) => Ok(format!("{msg} [type={}, context={}]", self.error_type, ctx.as_ref(py))), None => Ok(format!("{msg} [type={}, context=None]", self.error_type)), } @@ -173,7 +173,7 @@ impl PydanticKnownError { fn __repr__(&self, py: Python) -> PyResult { let msg = self.message(py)?; - match { self.context(py)?.as_ref() } { + match self.context(py)?.as_ref() { Some(ctx) => Ok(format!( "{msg} [type={}, context={}]", self.error_type(), From 61a656239d0e4748cd5c59a134372eab907b5cf6 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 15 Feb 2024 07:53:32 -0600 Subject: [PATCH 209/550] `date` string coerced to `datetime` shouldn't infer timezone (#1193) --- src/validators/datetime.rs | 3 +-- 
tests/validators/test_datetime.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/validators/datetime.rs b/src/validators/datetime.rs index 8779ea76c..94dc7ba05 100644 --- a/src/validators/datetime.rs +++ b/src/validators/datetime.rs @@ -140,7 +140,6 @@ impl Validator for DateTimeValidator { } /// In lax mode, if the input is not a datetime, we try parsing the input as a date and add the "00:00:00" time. -/// /// Ok(None) means that this is not relevant to datetimes (the input was not a date nor a string) fn datetime_from_date<'data>(input: &'data impl Input<'data>) -> Result>, ValError> { let either_date = match input.validate_date(false) { @@ -171,7 +170,7 @@ fn datetime_from_date<'data>(input: &'data impl Input<'data>) -> Result Date: Sat, 17 Feb 2024 14:22:42 +0200 Subject: [PATCH 210/550] Add benchmarks for serializing model with complete schema (#1189) --- tests/benchmarks/test_complete_benchmark.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/benchmarks/test_complete_benchmark.py b/tests/benchmarks/test_complete_benchmark.py index 57fb7645c..b13fddda7 100644 --- a/tests/benchmarks/test_complete_benchmark.py +++ b/tests/benchmarks/test_complete_benchmark.py @@ -8,7 +8,7 @@ import pytest -from pydantic_core import SchemaValidator, ValidationError, validate_core_schema +from pydantic_core import SchemaSerializer, SchemaValidator, ValidationError, validate_core_schema from .complete_schema import input_data_lax, input_data_strict, input_data_wrong, schema @@ -98,6 +98,25 @@ def test_complete_core_strict(benchmark): benchmark(v.validate_python, input_data_strict()) +@pytest.mark.benchmark(group='complete-to-python') +def test_complete_core_serializer_to_python(benchmark): + core_schema = validate_core_schema(schema()) + v = SchemaValidator(core_schema) + model = v.validate_python(input_data_lax()) + serializer = SchemaSerializer(core_schema) + assert serializer.to_python(model) == 
model.__dict__ + benchmark(serializer.to_python, model) + + +@pytest.mark.benchmark(group='complete-to-json') +def test_complete_core_serializer_to_json(benchmark): + core_schema = validate_core_schema(schema()) + v = SchemaValidator(core_schema) + model = v.validate_python(input_data_lax()) + serializer = SchemaSerializer(core_schema) + benchmark(serializer.to_json, model) + + @pytest.mark.benchmark(group='complete-wrong') def test_complete_core_error(benchmark): v = SchemaValidator(validate_core_schema(schema())) From ea443ba0b08bde949c9460aa2d43ad0058fd407c Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Mon, 19 Feb 2024 13:36:48 +0200 Subject: [PATCH 211/550] Prevent panicking when `__dict__` changes during iteration (#1196) --- src/serializers/fields.rs | 12 ++++++-- tests/serializers/test_model.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index cefdec1d7..c57729c46 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -5,6 +5,7 @@ use pyo3::types::{PyDict, PyString}; use ahash::AHashMap; use serde::ser::SerializeMap; +use smallvec::SmallVec; use crate::serializers::extra::SerCheck; use crate::PydanticSerializationUnexpectedValue; @@ -321,7 +322,7 @@ impl TypeSerializer for GeneralFieldsSerializer { return infer_to_python(value, include, exclude, &td_extra); }; - let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?; + let output_dict = self.main_to_python(py, dict_items(main_dict), include, exclude, td_extra)?; // this is used to include `__pydantic_extra__` in serialization on models if let Some(extra_dict) = extra_dict { @@ -373,7 +374,7 @@ impl TypeSerializer for GeneralFieldsSerializer { // NOTE! 
As above, we maintain the order of the input dict assuming that's right // we don't both with `used_fields` here because on unions, `to_python(..., mode='json')` is used let mut map = self.main_serde_serialize( - main_dict.iter().map(Ok), + dict_items(main_dict), expected_len, serializer, include, @@ -408,3 +409,10 @@ impl TypeSerializer for GeneralFieldsSerializer { fn key_str(key: &PyAny) -> PyResult<&str> { key.downcast::()?.to_str() } + +fn dict_items(main_dict: &PyDict) -> impl Iterator> { + // Collect items before iterating to prevent panic on dict mutation. + // Use a SmallVec to avoid heap allocation for models with a reasonable number of fields. + let main_items: SmallVec<[_; 16]> = main_dict.iter().collect(); + main_items.into_iter().map(Ok) +} diff --git a/tests/serializers/test_model.py b/tests/serializers/test_model.py index 0be83e538..c93825839 100644 --- a/tests/serializers/test_model.py +++ b/tests/serializers/test_model.py @@ -456,6 +456,58 @@ def ser_x(self, v: Any, _) -> str: assert s.to_python(Model(x=1000)) == {'x': '1_000'} +@pytest.mark.skipif(cached_property is None, reason='cached_property is not available') +def test_field_serializer_cached_property(): + @dataclasses.dataclass + class Model: + x: int + y: int + + @cached_property + def x_formatted(self) -> str: + return f'{self.x:_}' + + # This is a @computed_field + @cached_property + def y_formatted(self) -> str: + return f'{self.y:_}' + + def ser_x(self, v: Any, _) -> str: + assert self.x == 1_000 == v + return self.x_formatted + + def ser_y(self, v: Any, _) -> str: + assert self.y == 2_000 == v + return self.y_formatted + + s = SchemaSerializer( + core_schema.model_schema( + Model, + core_schema.model_fields_schema( + { + 'x': core_schema.model_field( + core_schema.int_schema( + serialization=core_schema.plain_serializer_function_ser_schema( + Model.ser_x, is_field_serializer=True, info_arg=True + ) + ) + ), + 'y': core_schema.model_field( + core_schema.int_schema( + 
serialization=core_schema.plain_serializer_function_ser_schema( + Model.ser_y, is_field_serializer=True, info_arg=True + ) + ) + ), + }, + computed_fields=[core_schema.computed_field('y_formatted', core_schema.str_schema())], + ), + ) + ) + assert s.to_python(Model(x=1000, y=2000)) == {'x': '1_000', 'y': '2_000', 'y_formatted': '2_000'} + assert s.to_json(Model(x=1000, y=2000)) == b'{"x":"1_000","y":"2_000","y_formatted":"2_000"}' + + def test_function_wrap_field_serializer_to_python(): @dataclasses.dataclass class Model: From 8c6e2bdd8b9ed9354d92587e69cf7ee2eff7ce7e Mon Sep 17 00:00:00 2001 From: Sam Dobson <1309834+samdobson@users.noreply.github.com> Date: Tue, 20 Feb 2024 09:57:02 +0000 Subject: [PATCH 212/550] Update pyodide to 0.25.0 (#1199) --- .github/workflows/ci.yml | 2 +- package.json | 2 +- wasm-preview/run_tests.py | 2 +- wasm-preview/worker.js | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6a4d32a4d..5c1675b62 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -309,7 +309,7 @@ jobs: # NOTE!: as per https://github.com/pydantic/pydantic-core/pull/149 this version needs to match the version # in node_modules/pyodide/repodata.json, to get the version, run: # `cat node_modules/pyodide/repodata.json | python -m json.tool | rg platform` - version: '3.1.32' + version: '3.1.46' actions-cache-folder: emsdk-cache - run: pip install 'maturin>=1,<2' 'ruff==0.1.3' typing_extensions diff --git a/package.json b/package.json index 4683dc942..8bb850741 100644 --- a/package.json +++ b/package.json @@ -8,7 +8,7 @@ "main": "tests/emscripten_runner.js", "dependencies": { "prettier": "^2.7.1", - "pyodide": "^0.23.0" + "pyodide": "^0.25.0" }, "scripts": { "test": "node tests/emscripten_runner.js", diff --git a/wasm-preview/run_tests.py b/wasm-preview/run_tests.py index f17b314a6..538c10e49 100644 --- a/wasm-preview/run_tests.py +++ b/wasm-preview/run_tests.py @@ -21,7 +21,7 @@ 
async def main(tests_zip: str, tag_name: str): # File saved on the GH release pydantic_core_wheel = ( 'https://githubproxy.samuelcolvin.workers.dev/pydantic/pydantic-core/releases/' - f'download/{tag_name}/pydantic_core-{tag_name.lstrip("v")}-cp311-cp311-emscripten_3_1_32_wasm32.whl' + f'download/{tag_name}/pydantic_core-{tag_name.lstrip("v")}-cp311-cp311-emscripten_3_1_46_wasm32.whl' ) zip_file = ZipFile(BytesIO(base64.b64decode(tests_zip))) count = 0 diff --git a/wasm-preview/worker.js b/wasm-preview/worker.js index 7ef438974..b769281ef 100644 --- a/wasm-preview/worker.js +++ b/wasm-preview/worker.js @@ -89,7 +89,7 @@ async function main() { get(`./run_tests.py?v=${Date.now()}`, 'text'), // e4cf2e2 commit matches the pydantic-core wheel being used, so tests should pass get(zip_url, 'blob'), - importScripts('https://cdn.jsdelivr.net/pyodide/v0.23.0/full/pyodide.js'), + importScripts('https://cdn.jsdelivr.net/pyodide/v0.25.0/full/pyodide.js'), ]); const pyodide = await loadPyodide(); From c2c90fd001d8df98fc0e864674d49f0a144fc85e Mon Sep 17 00:00:00 2001 From: Sarbjot Singh Date: Tue, 20 Feb 2024 05:44:45 -0500 Subject: [PATCH 213/550] Fix TzInfo equality check based on offset (#1197) Co-authored-by: David Hewitt --- src/input/datetime.rs | 15 +++++++++++++-- tests/test_tzinfo.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/input/datetime.rs b/src/input/datetime.rs index 0a3cdf929..ebb5675f2 100644 --- a/src/input/datetime.rs +++ b/src/input/datetime.rs @@ -563,8 +563,19 @@ impl TzInfo { hasher.finish() } - fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { - op.matches(self.seconds.cmp(&other.seconds)) + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult> { + let py = other.py(); + if other.is_instance_of::() { + let offset_delta = other.call_method1(intern!(py, "utcoffset"), (py.None(),))?; + if offset_delta.is_none() { + return Ok(py.NotImplemented()); + } + let offset_seconds: f64 = 
offset_delta.call_method0(intern!(py, "total_seconds"))?.extract()?; + let offset = offset_seconds.round() as i32; + Ok(op.matches(self.seconds.cmp(&offset)).into_py(py)) + } else { + Ok(py.NotImplemented()) + } } fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> PyResult> { diff --git a/tests/test_tzinfo.py b/tests/test_tzinfo.py index cb67b737e..949c9175d 100644 --- a/tests/test_tzinfo.py +++ b/tests/test_tzinfo.py @@ -1,11 +1,15 @@ import copy import functools import pickle +import sys import unittest from datetime import datetime, timedelta, timezone, tzinfo from pydantic_core import SchemaValidator, TzInfo, core_schema +if sys.version_info >= (3, 9): + from zoneinfo import ZoneInfo + class _ALWAYS_EQ: """ @@ -80,6 +84,7 @@ class TestTzInfo(unittest.TestCase): def setUp(self): self.ACDT = TzInfo(timedelta(hours=9.5).total_seconds()) self.EST = TzInfo(-timedelta(hours=5).total_seconds()) + self.UTC = TzInfo(timedelta(0).total_seconds()) self.DT = datetime(2010, 1, 1) def test_str(self): @@ -163,6 +168,17 @@ def test_comparison(self): self.assertFalse(tz <= SMALLEST) self.assertTrue(tz >= SMALLEST) + # offset based comparion tests for tzinfo derived classes like datetime.timezone. 
+ utcdatetime = self.DT.replace(tzinfo=timezone.utc) + self.assertTrue(tz == utcdatetime.tzinfo) + estdatetime = self.DT.replace(tzinfo=timezone(-timedelta(hours=5))) + self.assertTrue(self.EST == estdatetime.tzinfo) + self.assertTrue(tz > estdatetime.tzinfo) + if sys.version_info >= (3, 9) and sys.platform == 'linux': + self.assertFalse(tz == ZoneInfo('Europe/London')) + with self.assertRaises(TypeError): + tz > ZoneInfo('Europe/London') + def test_copy(self): for tz in self.ACDT, self.EST: tz_copy = copy.copy(tz) From 47aff70a7fb57c7fee3e63714fe7343d9cfbe4a5 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 20 Feb 2024 14:17:06 +0000 Subject: [PATCH 214/550] ci: if tzdata not available, skip comparing to `ZoneInfo` (#1200) --- tests/test_tzinfo.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test_tzinfo.py b/tests/test_tzinfo.py index 949c9175d..e67bf9098 100644 --- a/tests/test_tzinfo.py +++ b/tests/test_tzinfo.py @@ -8,7 +8,7 @@ from pydantic_core import SchemaValidator, TzInfo, core_schema if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo + from zoneinfo import ZoneInfo, ZoneInfoNotFoundError class _ALWAYS_EQ: @@ -174,10 +174,17 @@ def test_comparison(self): estdatetime = self.DT.replace(tzinfo=timezone(-timedelta(hours=5))) self.assertTrue(self.EST == estdatetime.tzinfo) self.assertTrue(tz > estdatetime.tzinfo) + if sys.version_info >= (3, 9) and sys.platform == 'linux': - self.assertFalse(tz == ZoneInfo('Europe/London')) - with self.assertRaises(TypeError): - tz > ZoneInfo('Europe/London') + try: + europe_london = ZoneInfo('Europe/London') + except ZoneInfoNotFoundError: + # tz data not available + pass + else: + self.assertFalse(tz == europe_london) + with self.assertRaises(TypeError): + tz > europe_london def test_copy(self): for tz in self.ACDT, self.EST: From 1a9046801ffdafe4bf29a2e2c15996b6efec852c Mon Sep 17 00:00:00 2001 From: Hungtsetse <33526088+hungtsetse@users.noreply.github.com> Date: 
Mon, 26 Feb 2024 19:39:30 +0800 Subject: [PATCH 215/550] Trimming str before parsing to int and float (#1203) Co-authored-by: Hung Tse Lee --- src/input/shared.rs | 3 ++- tests/validators/test_float.py | 1 + tests/validators/test_int.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/input/shared.rs b/src/input/shared.rs index 591c5abfc..c81dff1e8 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -72,6 +72,7 @@ fn strip_underscores(s: &str) -> Option { /// https://docs.python.org/3/whatsnew/3.11.html#other-cpython-implementation-changes and /// https://github.com/python/cpython/issues/95778 for more info in that length bound pub fn str_as_int<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult> { + let str = str.trim(); let len = str.len(); if len > 4300 { Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input)) @@ -96,7 +97,7 @@ pub fn str_as_int<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult< /// parse a float as a float pub fn str_as_float<'s, 'l>(input: &'s impl Input<'s>, str: &'l str) -> ValResult> { - match str.parse() { + match str.trim().parse() { Ok(float) => Ok(EitherFloat::F64(float)), Err(_) => match strip_underscores(str).and_then(|stripped| stripped.parse().ok()) { Some(float) => Ok(EitherFloat::F64(float)), diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index 56c03d40e..5af0d4adb 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -20,6 +20,7 @@ (1, 1), (42, 42), ('42', 42), + (' 42.1 ', 42.1), ('42.123', 42.123), (42.0, 42), (42.5, 42.5), diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 35a13f6a7..24928718d 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -21,6 +21,7 @@ (0, 0), ('0', 0), (1, 1), + (' 1 ', 1), (42, 42), ('42', 42), (42.0, 42), From 7779206d1dd3a8a2bb0203c727cf7a865a3fd672 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 26 Feb 2024 14:29:38 
+0000 Subject: [PATCH 216/550] support Rust 1.70 as MSRV (#1206) --- .github/workflows/ci.yml | 43 ++++++++++++++++++++++++++++------ .github/workflows/codspeed.yml | 2 -- Cargo.toml | 1 + src/serializers/fields.rs | 10 ++++---- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5c1675b62..e0a7cb9b3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,8 +24,6 @@ jobs: - id: cache-rust name: cache rust uses: Swatinem/rust-cache@v2 - with: - key: coverage-v2 - run: cargo install rustfilt coverage-prepare if: steps.cache-rust.outputs.cache-hit != 'true' @@ -120,14 +118,49 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 + + - name: set up python + uses: actions/setup-python@v5 with: - key: ${{ matrix.os }}-v1 + python-version: '3.11' + + - run: pip install -r tests/requirements.txt + + - run: pip install -e . + env: + RUST_BACKTRACE: 1 + + - run: pip freeze + + - run: pytest + + - run: cargo test + + test-msrv: + name: test MSRV + + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 - name: set up python uses: actions/setup-python@v5 with: python-version: '3.11' + - name: resolve MSRV + id: resolve-msrv + run: + echo MSRV=`python -c 'import tomllib; print(tomllib.load(open("Cargo.toml", "rb"))["package"]["rust-version"])'` >> $GITHUB_OUTPUT + + - name: install rust MSRV + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ steps.resolve-msrv.outputs.MSRV }} + + - name: cache rust + uses: Swatinem/rust-cache@v2 + - run: pip install -r tests/requirements.txt - run: pip install -e . 
@@ -164,8 +197,6 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 - with: - key: test-debug - run: pip install -r tests/requirements.txt - run: make build-dev @@ -197,8 +228,6 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 - with: - key: test-pydantic-integration - name: install deps run: | diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index 33497021f..332c6ae58 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -48,8 +48,6 @@ jobs: - name: cache rust uses: Swatinem/rust-cache@v2 - with: - key: v1 - name: Compile pydantic-core for profiling run: make build-profiling diff --git a/Cargo.toml b/Cargo.toml index 47b9d784c..7ba4759da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ include = [ "!tests/.pytest_cache", "!*.so", ] +rust-version = "1.70" [dependencies] pyo3 = { version = "0.20.2", features = ["generate-import-lib", "num-bigint"] } diff --git a/src/serializers/fields.rs b/src/serializers/fields.rs index c57729c46..4d192d344 100644 --- a/src/serializers/fields.rs +++ b/src/serializers/fields.rs @@ -147,7 +147,7 @@ impl GeneralFieldsSerializer { } } - pub fn main_to_python<'py>( + pub(crate) fn main_to_python<'py>( &self, py: Python<'py>, main_iter: impl Iterator>, @@ -212,7 +212,7 @@ impl GeneralFieldsSerializer { } } - pub fn main_serde_serialize<'py, S: serde::ser::Serializer>( + pub(crate) fn main_serde_serialize<'py, S: serde::ser::Serializer>( &self, main_iter: impl Iterator>, expected_len: usize, @@ -258,7 +258,7 @@ impl GeneralFieldsSerializer { Ok(map) } - pub fn add_computed_fields_python( + pub(crate) fn add_computed_fields_python( &self, model: Option<&PyAny>, output_dict: &PyDict, @@ -275,7 +275,7 @@ impl GeneralFieldsSerializer { Ok(()) } - pub fn add_computed_fields_json( + pub(crate) fn add_computed_fields_json( &self, model: Option<&PyAny>, map: &mut S::SerializeMap, @@ -291,7 +291,7 @@ impl GeneralFieldsSerializer { Ok(()) } - pub fn 
computed_field_count(&self) -> usize { + pub(crate) fn computed_field_count(&self) -> usize { option_length!(self.computed_fields) } } From 7fe78edc8582c3863b217ed864c5a35b4e08f1ae Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:56:23 -0700 Subject: [PATCH 217/550] Fix stack overflow due to recursion in some recursive serializer schemas. (#1198) Co-authored-by: David Hewitt --- src/definitions.rs | 67 +++++++++---- .../type_serializers/definitions.rs | 33 ++++++- tests/test.rs | 98 +++++++++++++++++++ 3 files changed, 174 insertions(+), 24 deletions(-) diff --git a/src/definitions.rs b/src/definitions.rs index 46a77196d..6fefb8c4a 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -4,6 +4,7 @@ /// We use DefinitionsBuilder to collect the references / definitions into a single vector /// and then get a definition from a reference using an integer id (just for performance of not using a HashMap) use std::{ + borrow::Borrow, collections::hash_map::Entry, fmt::Debug, sync::{ @@ -194,23 +195,39 @@ impl DefinitionsBuilder { } } -struct LazyName { - initialized: OnceLock, +/// Because definitions can create recursive structures, we often need to be able to populate +/// values lazily from these structures in a way that avoids infinite recursion. This structure +/// avoids infinite recursion by returning a default value when a recursion loop is detected. 
+pub(crate) struct RecursionSafeCache { + cache: OnceLock, in_recursion: AtomicBool, } -impl LazyName { - fn new() -> Self { +impl Clone for RecursionSafeCache { + fn clone(&self) -> Self { Self { - initialized: OnceLock::new(), + cache: self.cache.clone(), in_recursion: AtomicBool::new(false), } } +} - /// Gets the validator name, returning the default in the case of recursion loops - fn get_or_init(&self, init: impl FnOnce() -> String) -> &str { - if let Some(s) = self.initialized.get() { - return s.as_str(); +impl RecursionSafeCache { + /// Creates a new RecursionSafeCache + pub(crate) fn new() -> Self { + Self { + cache: OnceLock::new(), + in_recursion: AtomicBool::new(false), + } + } + + /// Gets or initialized the cached value, returning the default in the case of recursion loops + pub(crate) fn get_or_init(&self, init: impl FnOnce() -> T, recursive_default: &'static D) -> &D + where + T: Borrow, + { + if let Some(cached) = self.cache.get() { + return cached.borrow(); } if self @@ -218,25 +235,35 @@ impl LazyName { .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) .is_err() { - return "..."; + return recursive_default; } - let result = self.initialized.get_or_init(init).as_str(); + let result = self.cache.get_or_init(init).borrow(); self.in_recursion.store(false, Ordering::SeqCst); result } + + /// Gets the value, if it is set + fn get(&self) -> Option<&T> { + self.cache.get() + } } -impl Debug for LazyName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.initialized.get().map_or("...", String::as_str).fmt(f) +#[derive(Clone)] +struct LazyName(RecursionSafeCache); + +impl LazyName { + fn new() -> Self { + Self(RecursionSafeCache::new()) + } + + /// Gets the validator name, returning the default in the case of recursion loops + fn get_or_init(&self, init: impl FnOnce() -> String) -> &str { + self.0.get_or_init(init, "...") } } -impl Clone for LazyName { - fn clone(&self) -> Self { - Self { - initialized: 
OnceLock::new(), - in_recursion: AtomicBool::new(false), - } +impl Debug for LazyName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.get().map_or("...", String::as_str).fmt(f) } } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 2f98a94e0..3e2d48b19 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -4,8 +4,8 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; -use crate::definitions::DefinitionRef; use crate::definitions::DefinitionsBuilder; +use crate::definitions::{DefinitionRef, RecursionSafeCache}; use crate::tools::SchemaDict; @@ -39,9 +39,28 @@ impl BuildSerializer for DefinitionsSerializerBuilder { } } -#[derive(Debug, Clone)] pub struct DefinitionRefSerializer { definition: DefinitionRef, + retry_with_lax_check: RecursionSafeCache, +} + +// TODO(DH): Remove the need to clone serializers +impl Clone for DefinitionRefSerializer { + fn clone(&self) -> Self { + Self { + definition: self.definition.clone(), + retry_with_lax_check: RecursionSafeCache::new(), + } + } +} + +impl std::fmt::Debug for DefinitionRefSerializer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DefinitionRefSerializer") + .field("definition", &self.definition) + .field("retry_with_lax_check", &self.retry_with_lax_check()) + .finish() + } } impl BuildSerializer for DefinitionRefSerializer { @@ -54,7 +73,11 @@ impl BuildSerializer for DefinitionRefSerializer { ) -> PyResult { let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?; let definition = definitions.get_definition(schema_ref); - Ok(Self { definition }.into()) + Ok(Self { + definition, + retry_with_lax_check: RecursionSafeCache::new(), + } + .into()) } } @@ -101,6 +124,8 @@ impl TypeSerializer for DefinitionRefSerializer { } fn retry_with_lax_check(&self) -> bool { - 
self.definition.read(|s| s.unwrap().retry_with_lax_check()) + *self + .retry_with_lax_check + .get_or_init(|| self.definition.read(|s| s.unwrap().retry_with_lax_check()), &false) } } diff --git a/tests/test.rs b/tests/test.rs index e0b76a4d8..4a5776e7d 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -117,4 +117,102 @@ json_input = '{"a": "something"}' assert_eq!(repr, "{'a': 'something'}"); }); } + + #[test] + fn test_segfault_for_recursive_schemas() { + Python::with_gil(|py| { + let code = r" +schema = { + 'type': 'definitions', + 'schema': { + 'type': 'definition-ref', + 'schema_ref': '__main__.JSONData:4303261344' + }, + 'definitions': [ + { + 'type': 'union', + 'choices': [ + { + 'type': 'dict', + 'keys_schema': {'type': 'str'}, + 'values_schema': { + 'type': 'definition-ref', + 'schema_ref': '__main__.JSONData:4303261344' + }, + 'strict': False + }, + { + 'type': 'list', + 'items_schema': { + 'type': 'definition-ref', + 'schema_ref': '__main__.JSONData:4303261344' + }, + 'strict': False + } + ], + 'ref': '__main__.JSONData:4303261344' + } + ] +} +dump_json_input_1 = 1 +dump_json_input_2 = {'a': 'something'} + "; + let locals = PyDict::new(py); + py.run(code, None, Some(locals)).unwrap(); + let schema: &PyDict = locals.get_item("schema").unwrap().unwrap().extract().unwrap(); + let dump_json_input_1: &PyAny = locals + .get_item("dump_json_input_1") + .unwrap() + .unwrap() + .extract() + .unwrap(); + let dump_json_input_2: &PyAny = locals + .get_item("dump_json_input_2") + .unwrap() + .unwrap() + .extract() + .unwrap(); + let binding = SchemaSerializer::py_new(py, schema, None) + .unwrap() + .to_json( + py, + dump_json_input_1, + None, + None, + None, + false, + false, + false, + false, + false, + false, + None, + ) + .unwrap(); + let serialization_result: &PyAny = binding.extract(py).unwrap(); + let repr = format!("{}", serialization_result.repr().unwrap()); + assert_eq!(repr, "b'1'"); + + let binding = SchemaSerializer::py_new(py, schema, None) + .unwrap() + 
.to_json( + py, + dump_json_input_2, + None, + None, + None, + false, + false, + false, + false, + false, + false, + None, + ) + .unwrap(); + let serialization_result: &PyAny = binding.extract(py).unwrap(); + let repr = format!("{}", serialization_result.repr().unwrap()); + assert_eq!(repr, "b'{\"a\":\"something\"}'"); + }); + } } From ad5007dc86e9d18192377f5fb41659699097c50b Mon Sep 17 00:00:00 2001 From: stonebig Date: Mon, 4 Mar 2024 14:55:38 +0100 Subject: [PATCH 218/550] Update pyproject.toml to get a 'Summary' metadata (#1214) Co-authored-by: David Hewitt --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 5cdbca806..9683cf642 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ build-backend = 'maturin' [project] name = 'pydantic_core' +description = "Core functionality for Pydantic validation and serialization" requires-python = '>=3.8' authors = [ {name = 'Samuel Colvin', email = 's@muelcolvin.com'} From 68a7e968722bda860f78e4fbcf212690e003bb86 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:18:33 +0000 Subject: [PATCH 219/550] Bump serde_json from 1.0.113 to 1.0.114 (#1211) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c00c1131e..c78b54265 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -514,9 +514,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.113" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "indexmap", "itoa", diff --git a/Cargo.toml b/Cargo.toml index 7ba4759da..9abfac444 100644 
--- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ pyo3 = { version = "0.20.2", features = ["generate-import-lib", "num-bigint"] } regex = "1.10.3" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" -serde_json = {version = "1.0.113", features = ["arbitrary_precision", "preserve_order"]} +serde_json = {version = "1.0.114", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.196", features = ["derive"] } speedate = "0.13.0" From 1f905160ecc5aeb3e56423beffec067d5bd8340d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:19:16 +0000 Subject: [PATCH 220/550] Bump ahash from 0.8.7 to 0.8.10 (#1210) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c78b54265..cc8525f03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "ahash" -version = "0.8.7" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "8b79b82693f705137f8fb9b37871d99e4f9a7df12b917eed79c3d3954830a60b" dependencies = [ "cfg-if", "getrandom", diff --git a/Cargo.toml b/Cargo.toml index 9abfac444..90b1cb4d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ enum_dispatch = "0.3.8" serde = { version = "1.0.196", features = ["derive"] } speedate = "0.13.0" smallvec = "1.13.1" -ahash = "0.8.7" +ahash = "0.8.10" url = "2.5.0" # idna is already required by url, added here to be explicit idna = "0.5.0" From 19769b8d62c6aafb6b46e4980d7bae21c4c6597c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:24:40 +0000 Subject: [PATCH 221/550] 
Bump pyo3 from 0.20.2 to 0.20.3 (#1212) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 28 ++++++++++++++++++---------- Cargo.toml | 4 ++-- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cc8525f03..ef9843b7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,6 +321,12 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "proc-macro2" version = "1.0.76" @@ -357,9 +363,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" +checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" dependencies = [ "cfg-if", "indoc", @@ -367,6 +373,7 @@ dependencies = [ "memoffset", "num-bigint", "parking_lot", + "portable-atomic", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -375,9 +382,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" +checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" dependencies = [ "once_cell", "python3-dll-a", @@ -386,9 +393,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" +checksum = 
"62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" dependencies = [ "libc", "pyo3-build-config", @@ -396,9 +403,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" +checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -408,12 +415,13 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" +checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" dependencies = [ "heck", "proc-macro2", + "pyo3-build-config", "quote", "syn", ] diff --git a/Cargo.toml b/Cargo.toml index 90b1cb4d9..c437e4d6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ rust-version = "1.70" [dependencies] -pyo3 = { version = "0.20.2", features = ["generate-import-lib", "num-bigint"] } +pyo3 = { version = "0.20.3", features = ["generate-import-lib", "num-bigint"] } regex = "1.10.3" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.25.3" @@ -71,7 +71,7 @@ debug = true strip = false [dev-dependencies] -pyo3 = { version = "0.20.2", features = ["auto-initialize"] } +pyo3 = { version = "0.20.3", features = ["auto-initialize"] } [build-dependencies] version_check = "0.9.4" From 6d23acdb1ba547983d7736212f34cc9578dbe831 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:30:20 +0000 Subject: [PATCH 222/550] Bump strum_macros from 0.25.3 to 0.26.1 (#1208) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 19 
++++++++++++++++--- Cargo.toml | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ef9843b7e..b74074c8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -355,7 +355,7 @@ dependencies = [ "smallvec", "speedate", "strum", - "strum_macros", + "strum_macros 0.26.1", "url", "uuid", "version_check", @@ -545,7 +545,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "242f76c50fd18cbf098607090ade73a08d39cfd84ea835f3796a2c855223b19b" dependencies = [ "strum", - "strum_macros", + "strum_macros 0.25.3", ] [[package]] @@ -560,7 +560,7 @@ version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" dependencies = [ - "strum_macros", + "strum_macros 0.25.3", ] [[package]] @@ -576,6 +576,19 @@ dependencies = [ "syn", ] +[[package]] +name = "strum_macros" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a3417fc93d76740d974a01654a09777cb500428cc874ca9f45edfe0c4d4cd18" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "syn" version = "2.0.48" diff --git a/Cargo.toml b/Cargo.toml index c437e4d6a..ea3e3d73c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ rust-version = "1.70" pyo3 = { version = "0.20.3", features = ["generate-import-lib", "num-bigint"] } regex = "1.10.3" strum = { version = "0.25.0", features = ["derive"] } -strum_macros = "0.25.3" +strum_macros = "0.26.1" serde_json = {version = "1.0.114", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" serde = { version = "1.0.196", features = ["derive"] } From a98b8b040a92e49561305e200f1202616cd8c3ac Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 14:53:23 +0000 Subject: [PATCH 223/550] Bump the python-packages group with 7 updates (#1213) 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: David Hewitt --- Makefile | 4 +-- generate_self_schema.py | 1 + pyproject.toml | 2 ++ python/pydantic_core/core_schema.py | 39 +++++++-------------- tests/benchmarks/test_complete_benchmark.py | 1 + tests/benchmarks/test_micro_benchmarks.py | 1 + tests/requirements-linting.txt | 6 ++-- tests/requirements.txt | 8 ++--- tests/test_typing.py | 12 +++---- 9 files changed, 31 insertions(+), 43 deletions(-) diff --git a/Makefile b/Makefile index 8fbe91389..0b361de63 100644 --- a/Makefile +++ b/Makefile @@ -90,13 +90,13 @@ build-wasm: .PHONY: format format: - ruff --fix $(sources) + ruff check --fix $(sources) ruff format $(sources) cargo fmt .PHONY: lint-python lint-python: - ruff $(sources) + ruff check $(sources) ruff format --check $(sources) $(mypy-stubtest) griffe dump -f -d google -LWARNING -o/dev/null python/pydantic_core diff --git a/generate_self_schema.py b/generate_self_schema.py index 8d27247d6..3aef99b18 100644 --- a/generate_self_schema.py +++ b/generate_self_schema.py @@ -4,6 +4,7 @@ The schema is generated from `python/pydantic_core/core_schema.py`. """ + from __future__ import annotations as _annotations import decimal diff --git a/pyproject.toml b/pyproject.toml index 9683cf642..d56dc1255 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,8 @@ features = ["pyo3/extension-module"] [tool.ruff] line-length = 120 + +[tool.ruff.lint] extend-select = ['Q', 'RUF100', 'C90', 'I'] extend-ignore = [ 'E721', # using type() instead of isinstance() - we use this in tests diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 31bf48782..bef04bc3a 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -117,51 +117,39 @@ class CoreConfig(TypedDict, total=False): class SerializationInfo(Protocol): @property - def include(self) -> IncExCall: - ... 
+ def include(self) -> IncExCall: ... @property - def exclude(self) -> IncExCall: - ... + def exclude(self) -> IncExCall: ... @property - def mode(self) -> str: - ... + def mode(self) -> str: ... @property - def by_alias(self) -> bool: - ... + def by_alias(self) -> bool: ... @property - def exclude_unset(self) -> bool: - ... + def exclude_unset(self) -> bool: ... @property - def exclude_defaults(self) -> bool: - ... + def exclude_defaults(self) -> bool: ... @property - def exclude_none(self) -> bool: - ... + def exclude_none(self) -> bool: ... @property - def round_trip(self) -> bool: - ... + def round_trip(self) -> bool: ... - def mode_is_json(self) -> bool: - ... + def mode_is_json(self) -> bool: ... - def __str__(self) -> str: - ... + def __str__(self) -> str: ... - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... class FieldSerializationInfo(SerializationInfo, Protocol): @property - def field_name(self) -> str: - ... + def field_name(self) -> str: ... class ValidationInfo(Protocol): @@ -305,8 +293,7 @@ def plain_serializer_function_ser_schema( class SerializerFunctionWrapHandler(Protocol): # pragma: no cover - def __call__(self, __input_value: Any, __index_key: int | str | None = None) -> Any: - ... + def __call__(self, __input_value: Any, __index_key: int | str | None = None) -> Any: ... # (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any diff --git a/tests/benchmarks/test_complete_benchmark.py b/tests/benchmarks/test_complete_benchmark.py index b13fddda7..3217efe13 100644 --- a/tests/benchmarks/test_complete_benchmark.py +++ b/tests/benchmarks/test_complete_benchmark.py @@ -1,6 +1,7 @@ """ General benchmarks that attempt to cover all field types, through by no means all uses of all field types. 
""" + import json from datetime import date, datetime, time from decimal import Decimal diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index c2320427c..6086c5bd7 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -1,6 +1,7 @@ """ Numerous benchmarks of specific functionality. """ + import decimal import json import platform diff --git a/tests/requirements-linting.txt b/tests/requirements-linting.txt index 1692fa86d..8e3ea74da 100644 --- a/tests/requirements-linting.txt +++ b/tests/requirements-linting.txt @@ -1,4 +1,4 @@ -griffe==0.40.0 -pyright==1.1.349 -ruff==0.1.15 +griffe==0.41.0 +pyright==1.1.352 +ruff==0.3.0 mypy==1.8.0 diff --git a/tests/requirements.txt b/tests/requirements.txt index 756792f2a..b2c8dba6b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,12 +1,12 @@ -coverage==7.4.1 +coverage==7.4.3 dirty-equals==0.7.1.post0 -hypothesis==6.97.4 +hypothesis==6.98.15 # TODO: remove manual override for dateutil once a version newer than 2.8.2 is # released which removes use of deprecated utcfromtimestamp git+https://github.com/dateutil/dateutil.git@f2293200747fb03d56c6c5997bfebeabe703576f # pandas doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. aarch64 musllinux pandas==2.1.3; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64' -pytest==8.0.0 +pytest==8.0.2 # we run codspeed benchmarks on x86_64 CPython (i.e. native github actions architecture) pytest-codspeed~=2.2.0; implementation_name == "cpython" and platform_machine == 'x86_64' # pytest-examples currently depends on aiohttp via black; we don't want to build @@ -16,7 +16,7 @@ pytest-speed==0.3.5 pytest-mock==3.11.1 pytest-pretty==1.2.0 pytest-timeout==2.2.0 -pytz==2023.4 +pytz==2024.1 # numpy doesn't offer prebuilt wheels for all versions and platforms we test in CI e.g. 
aarch64 musllinux numpy==1.26.2; python_version >= "3.9" and python_version < "3.13" and implementation_name == "cpython" and platform_machine == 'x86_64' exceptiongroup==1.1; python_version < "3.11" diff --git a/tests/test_typing.py b/tests/test_typing.py index fb9f3e949..00f58e31f 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -19,20 +19,16 @@ class Foo: bar: str -def foo(bar: str) -> None: - ... +def foo(bar: str) -> None: ... -def validator_deprecated(value: Any, info: core_schema.FieldValidationInfo) -> None: - ... +def validator_deprecated(value: Any, info: core_schema.FieldValidationInfo) -> None: ... -def validator(value: Any, info: core_schema.ValidationInfo) -> None: - ... +def validator(value: Any, info: core_schema.ValidationInfo) -> None: ... -def wrap_validator(value: Any, call_next: Callable[[Any], Any], info: core_schema.ValidationInfo) -> None: - ... +def wrap_validator(value: Any, call_next: Callable[[Any], Any], info: core_schema.ValidationInfo) -> None: ... def test_schema_typing() -> None: From f669db9c054cceb4248784e4253edc8f72fa81ad Mon Sep 17 00:00:00 2001 From: Victorien <65306057+Viicos@users.noreply.github.com> Date: Tue, 5 Mar 2024 18:26:46 +0100 Subject: [PATCH 224/550] Use PEP570 syntax (#1216) --- python/pydantic_core/_pydantic_core.pyi | 2 +- python/pydantic_core/core_schema.py | 30 ++++++++++++------------- tests/test_typing.py | 8 +++---- tests/validators/test_model_fields.py | 10 ++++----- tests/validators/test_typed_dict.py | 4 ++-- 5 files changed, 27 insertions(+), 27 deletions(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index a7b727f86..a6001f63c 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -73,7 +73,7 @@ class Some(Generic[_T]): Returns the value wrapped by `Some`. """ @classmethod - def __class_getitem__(cls, __item: Any) -> Type[Self]: ... 
+ def __class_getitem__(cls, item: Any, /) -> Type[Self]: ... @final class SchemaValidator: diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index bef04bc3a..ba8ce15ad 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -225,13 +225,13 @@ def simple_ser_schema(type: ExpectedSerializationTypes) -> SimpleSerSchema: return SimpleSerSchema(type=type) -# (__input_value: Any) -> Any +# (input_value: Any, /) -> Any GeneralPlainNoInfoSerializerFunction = Callable[[Any], Any] -# (__input_value: Any, __info: FieldSerializationInfo) -> Any +# (input_value: Any, info: FieldSerializationInfo, /) -> Any GeneralPlainInfoSerializerFunction = Callable[[Any, SerializationInfo], Any] -# (__model: Any, __input_value: Any) -> Any +# (model: Any, input_value: Any, /) -> Any FieldPlainNoInfoSerializerFunction = Callable[[Any, Any], Any] -# (__model: Any, __input_value: Any, __info: FieldSerializationInfo) -> Any +# (model: Any, input_value: Any, info: FieldSerializationInfo, /) -> Any FieldPlainInfoSerializerFunction = Callable[[Any, Any, FieldSerializationInfo], Any] SerializerFunction = Union[ GeneralPlainNoInfoSerializerFunction, @@ -275,7 +275,7 @@ def plain_serializer_function_ser_schema( function: The function to use for serialization is_field_serializer: Whether the serializer is for a field, e.g. takes `model` as the first argument, and `info` includes `field_name` - info_arg: Whether the function takes an `__info` argument + info_arg: Whether the function takes an `info` argument return_schema: Schema to use for serializing return value when_used: When the function should be called """ @@ -293,16 +293,16 @@ def plain_serializer_function_ser_schema( class SerializerFunctionWrapHandler(Protocol): # pragma: no cover - def __call__(self, __input_value: Any, __index_key: int | str | None = None) -> Any: ... + def __call__(self, input_value: Any, index_key: int | str | None = None, /) -> Any: ... 
-# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any +# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any GeneralWrapNoInfoSerializerFunction = Callable[[Any, SerializerFunctionWrapHandler], Any] -# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any +# (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any GeneralWrapInfoSerializerFunction = Callable[[Any, SerializerFunctionWrapHandler, SerializationInfo], Any] -# (__model: Any, __input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any +# (model: Any, input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any FieldWrapNoInfoSerializerFunction = Callable[[Any, Any, SerializerFunctionWrapHandler], Any] -# (__model: Any, __input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: FieldSerializationInfo) -> Any +# (model: Any, input_value: Any, serializer: SerializerFunctionWrapHandler, info: FieldSerializationInfo, /) -> Any FieldWrapInfoSerializerFunction = Callable[[Any, Any, SerializerFunctionWrapHandler, FieldSerializationInfo], Any] WrapSerializerFunction = Union[ GeneralWrapNoInfoSerializerFunction, @@ -338,7 +338,7 @@ def wrap_serializer_function_ser_schema( function: The function to use for serialization is_field_serializer: Whether the serializer is for a field, e.g. 
takes `model` as the first argument, and `info` includes `field_name` - info_arg: Whether the function takes an `__info` argument + info_arg: Whether the function takes an `info` argument schema: The schema to use for the inner serialization return_schema: Schema to use for serializing return value when_used: When the function should be called @@ -1767,7 +1767,7 @@ def dict_schema( ) -# (__input_value: Any) -> Any +# (input_value: Any, /) -> Any NoInfoValidatorFunction = Callable[[Any], Any] @@ -1776,7 +1776,7 @@ class NoInfoValidatorFunctionSchema(TypedDict): function: NoInfoValidatorFunction -# (__input_value: Any, __info: ValidationInfo) -> Any +# (input_value: Any, info: ValidationInfo, /) -> Any WithInfoValidatorFunction = Callable[[Any, ValidationInfo], Any] @@ -1990,7 +1990,7 @@ def __call__(self, input_value: Any, outer_location: str | int | None = None) -> ... -# (__input_value: Any, __validator: ValidatorFunctionWrapHandler) -> Any +# (input_value: Any, validator: ValidatorFunctionWrapHandler, /) -> Any NoInfoWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler], Any] @@ -1999,7 +1999,7 @@ class NoInfoWrapValidatorFunctionSchema(TypedDict): function: NoInfoWrapValidatorFunction -# (__input_value: Any, __validator: ValidatorFunctionWrapHandler, __info: ValidationInfo) -> Any +# (input_value: Any, validator: ValidatorFunctionWrapHandler, info: ValidationInfo, /) -> Any WithInfoWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler, ValidationInfo], Any] diff --git a/tests/test_typing.py b/tests/test_typing.py index 00f58e31f..63d17a790 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -221,8 +221,8 @@ def test_type_error(): def test_ser_function_plain(): - def f(__input: Any, __info: core_schema.SerializationInfo) -> str: - return str(__info) + def f(input: Any, info: core_schema.SerializationInfo, /) -> str: + return str(info) s = SchemaSerializer( core_schema.any_schema( @@ -239,9 +239,9 @@ def f(__input: Any, 
__info: core_schema.SerializationInfo) -> str: def test_ser_function_wrap(): def f( - __input: Any, __serialize: core_schema.SerializerFunctionWrapHandler, __info: core_schema.SerializationInfo + input: Any, serialize: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo, / ) -> str: - return f'{__serialize} {__info}' + return f'{serialize} {info}' s = SchemaSerializer( core_schema.any_schema( diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index e21f50008..8a22d96c3 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -32,8 +32,8 @@ def __iter__(self): def __len__(self) -> int: return len(self._d) - def __getitem__(self, __k): - return self._d[__k] + def __getitem__(self, k, /): + return self._d[k] def __repr__(self): return 'Map({})'.format(', '.join(f'{k}={v!r}' for k, v in self._d.items())) @@ -1188,9 +1188,9 @@ class Source: a = 1 b = 2 - def __getattribute__(self, __name: str) -> Any: - accessed.append(__name) - return super().__getattribute__(__name) + def __getattribute__(self, name: str, /) -> Any: + accessed.append(name) + return super().__getattribute__(name) assert v.validate_python(Source()) == ({'a': 1}, None, {'a'}) assert 'a' in accessed and 'b' not in accessed diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 8fb25cff6..dc18cd86e 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -32,8 +32,8 @@ def __iter__(self): def __len__(self) -> int: return len(self._d) - def __getitem__(self, __k): - return self._d[__k] + def __getitem__(self, k, /): + return self._d[k] def __repr__(self): return 'Map({})'.format(', '.join(f'{k}={v!r}' for k, v in self._d.items())) From c6301fe5966b98b11970df49129e34f76616c9dc Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:56:21 -0600 Subject: [PATCH 
225/550] Fix parsing BigInt from str (#1204) --- src/input/input_json.rs | 5 +---- src/input/input_string.rs | 7 ++----- tests/validators/test_int.py | 8 ++++++++ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/input/input_json.rs b/src/input/input_json.rs index cd4ac919c..dde73cf65 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -370,10 +370,7 @@ impl<'a> Input<'a> for String { } fn validate_int(&'a self, _strict: bool) -> ValResult>> { - match self.parse() { - Ok(i) => Ok(ValidationMatch::lax(EitherInt::I64(i))), - Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), - } + str_as_int(self, self).map(ValidationMatch::lax) } fn validate_float(&'a self, _strict: bool) -> ValResult>> { diff --git a/src/input/input_string.rs b/src/input/input_string.rs index cd5931b24..f6c33d784 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -11,7 +11,7 @@ use crate::validators::decimal::create_decimal; use super::datetime::{ bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, EitherDate, EitherDateTime, EitherTime, }; -use super::shared::{str_as_bool, str_as_float}; +use super::shared::{str_as_bool, str_as_float, str_as_int}; use super::{ BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator, GenericMapping, Input, ValidationMatch, @@ -112,10 +112,7 @@ impl<'a> Input<'a> for StringMapping<'a> { fn validate_int(&'a self, _strict: bool) -> ValResult>> { match self { - Self::String(s) => match py_string_str(s)?.parse() { - Ok(i) => Ok(ValidationMatch::strict(EitherInt::I64(i))), - Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)), - }, + Self::String(s) => str_as_int(self, py_string_str(s)?).map(ValidationMatch::strict), Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::IntType, self)), } } diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index 24928718d..c5ccd7957 
100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -504,3 +504,11 @@ def test_allow_inf_nan_false_json() -> None: v.validate_json('Infinity') with pytest.raises(ValidationError, match=r'Input should be a finite number \[type=finite_number'): v.validate_json('-Infinity') + + +def test_json_big_int_key(): + v = SchemaValidator({'type': 'dict', 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'str'}}) + big_integer = 1433352099889938534014333520998899385340 + assert v.validate_python({big_integer: 'x'}) == {big_integer: 'x'} + assert v.validate_json('{"' + str(big_integer) + '": "x"}') == {big_integer: 'x'} + assert v.validate_strings({str(big_integer): 'x'}) == {big_integer: 'x'} From 1083986b5e948778cd09e01a6e05c7b2fc670d8a Mon Sep 17 00:00:00 2001 From: ornariece Date: Wed, 6 Mar 2024 11:58:44 +0100 Subject: [PATCH 226/550] ability to pass context to serialization (pydantic#7143) (#1215) Co-authored-by: ornariece <37-ornariece@users.noreply.git.malined.com> --- python/pydantic_core/_pydantic_core.pyi | 22 +++++-- python/pydantic_core/core_schema.py | 4 ++ src/errors/validation_exception.rs | 2 +- src/serializers/extra.rs | 12 +++- src/serializers/infer.rs | 2 + src/serializers/mod.rs | 18 ++++-- src/serializers/type_serializers/function.rs | 61 ++++++++++++++++++- src/serializers/type_serializers/generator.rs | 29 +++++++++ src/validators/function.rs | 12 ++++ tests/serializers/test_functions.py | 26 ++++++-- tests/test.rs | 6 +- tests/test_typing.py | 4 +- 12 files changed, 177 insertions(+), 21 deletions(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index a6001f63c..347589fd4 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -101,7 +101,7 @@ class SchemaValidator: *, strict: bool | None = None, from_attributes: bool | None = None, - context: 'dict[str, Any] | None' = None, + context: dict[str, Any] | None = None, 
self_instance: Any | None = None, ) -> Any: """ @@ -131,7 +131,7 @@ class SchemaValidator: *, strict: bool | None = None, from_attributes: bool | None = None, - context: 'dict[str, Any] | None' = None, + context: dict[str, Any] | None = None, self_instance: Any | None = None, ) -> bool: """ @@ -148,7 +148,7 @@ class SchemaValidator: input: str | bytes | bytearray, *, strict: bool | None = None, - context: 'dict[str, Any] | None' = None, + context: dict[str, Any] | None = None, self_instance: Any | None = None, ) -> Any: """ @@ -176,7 +176,7 @@ class SchemaValidator: The validated Python object. """ def validate_strings( - self, input: _StringInput, *, strict: bool | None = None, context: 'dict[str, Any] | None' = None + self, input: _StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None ) -> Any: """ Validate a string against the schema and return the validated Python object. @@ -206,7 +206,7 @@ class SchemaValidator: *, strict: bool | None = None, from_attributes: bool | None = None, - context: 'dict[str, Any] | None' = None, + context: dict[str, Any] | None = None, ) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any] | None, set[str]]: """ Validate an assignment to a field on a model. @@ -278,6 +278,7 @@ class SchemaSerializer: round_trip: bool = False, warnings: bool = True, fallback: Callable[[Any], Any] | None = None, + context: dict[str, Any] | None = None, ) -> Any: """ Serialize/marshal a Python object to a Python object including transforming and filtering data. @@ -297,6 +298,8 @@ class SchemaSerializer: warnings: Whether to log warnings when invalid fields are encountered. fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + context: The context to use for serialization, this is passed to functional serializers as + [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. 
Raises: PydanticSerializationError: If serialization fails and no `fallback` function is provided. @@ -318,6 +321,7 @@ class SchemaSerializer: round_trip: bool = False, warnings: bool = True, fallback: Callable[[Any], Any] | None = None, + context: dict[str, Any] | None = None, ) -> bytes: """ Serialize a Python object to JSON including transforming and filtering data. @@ -336,6 +340,8 @@ class SchemaSerializer: warnings: Whether to log warnings when invalid fields are encountered. fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + context: The context to use for serialization, this is passed to functional serializers as + [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. Raises: PydanticSerializationError: If serialization fails and no `fallback` function is provided. @@ -358,6 +364,7 @@ def to_json( inf_nan_mode: Literal['null', 'constants'] = 'constants', serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, + context: dict[str, Any] | None = None, ) -> bytes: """ Serialize a Python object to JSON including transforming and filtering data. @@ -379,6 +386,8 @@ def to_json( `""` will be used. fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + context: The context to use for serialization, this is passed to functional serializers as + [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. Raises: PydanticSerializationError: If serialization fails and no `fallback` function is provided. 
@@ -419,6 +428,7 @@ def to_jsonable_python( inf_nan_mode: Literal['null', 'constants'] = 'constants', serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, + context: dict[str, Any] | None = None, ) -> Any: """ Serialize/marshal a Python object to a JSON-serializable Python object including transforming and filtering data. @@ -440,6 +450,8 @@ def to_jsonable_python( `""` will be used. fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + context: The context to use for serialization, this is passed to functional serializers as + [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. Raises: PydanticSerializationError: If serialization fails and no `fallback` function is provided. diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index ba8ce15ad..8e36e8da0 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -122,6 +122,10 @@ def include(self) -> IncExCall: ... @property def exclude(self) -> IncExCall: ... + @property + def context(self) -> Any | None: + """Current serialization context.""" + @property def mode(self) -> str: ... 
diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 24626df29..e87f4b0d9 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -320,7 +320,7 @@ impl ValidationError { include_input: bool, ) -> PyResult<&'py PyString> { let state = SerializationState::new("iso8601", "utf8", "constants")?; - let extra = state.extra(py, &SerMode::Json, true, false, false, true, None); + let extra = state.extra(py, &SerMode::Json, true, false, false, true, None, None); let serializer = ValidationErrorSerializer { py, line_errors: &self.line_errors, diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 8d598d46b..3cff3f370 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -45,6 +45,7 @@ impl SerializationState { round_trip: bool, serialize_unknown: bool, fallback: Option<&'py PyAny>, + context: Option<&'py PyAny>, ) -> Extra<'py> { Extra::new( py, @@ -59,6 +60,7 @@ impl SerializationState { &self.rec_guard, serialize_unknown, fallback, + context, ) } @@ -90,6 +92,7 @@ pub(crate) struct Extra<'a> { pub field_name: Option<&'a str>, pub serialize_unknown: bool, pub fallback: Option<&'a PyAny>, + pub context: Option<&'a PyAny>, } impl<'a> Extra<'a> { @@ -107,6 +110,7 @@ impl<'a> Extra<'a> { rec_guard: &'a SerRecursionState, serialize_unknown: bool, fallback: Option<&'a PyAny>, + context: Option<&'a PyAny>, ) -> Self { Self { mode, @@ -124,6 +128,7 @@ impl<'a> Extra<'a> { field_name: None, serialize_unknown, fallback, + context, } } @@ -178,10 +183,11 @@ pub(crate) struct ExtraOwned { config: SerializationConfig, rec_guard: SerRecursionState, check: SerCheck, - model: Option, + pub model: Option, field_name: Option, serialize_unknown: bool, - fallback: Option, + pub fallback: Option, + pub context: Option, } impl ExtraOwned { @@ -201,6 +207,7 @@ impl ExtraOwned { field_name: extra.field_name.map(ToString::to_string), serialize_unknown: extra.serialize_unknown, fallback: 
extra.fallback.map(Into::into), + context: extra.context.map(Into::into), } } @@ -221,6 +228,7 @@ impl ExtraOwned { field_name: self.field_name.as_deref(), serialize_unknown: self.serialize_unknown, fallback: self.fallback.as_ref().map(|m| m.as_ref(py)), + context: self.context.as_ref().map(|m| m.as_ref(py)), } } } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index edbe614e4..98ffc72b2 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -99,6 +99,7 @@ pub(crate) fn infer_to_python_known( extra.rec_guard, extra.serialize_unknown, extra.fallback, + extra.context, ); serializer.serializer.to_python(value, include, exclude, &extra) }; @@ -468,6 +469,7 @@ pub(crate) fn infer_serialize_known( extra.rec_guard, extra.serialize_unknown, extra.fallback, + extra.context, ); let pydantic_serializer = PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 7d9c5347c..3ed47a0a8 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -55,6 +55,7 @@ impl SchemaSerializer { rec_guard: &'a SerRecursionState, serialize_unknown: bool, fallback: Option<&'a PyAny>, + context: Option<&'a PyAny>, ) -> Extra<'b> { Extra::new( py, @@ -69,6 +70,7 @@ impl SchemaSerializer { rec_guard, serialize_unknown, fallback, + context, ) } } @@ -95,7 +97,7 @@ impl SchemaSerializer { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (value, *, mode = None, include = None, exclude = None, by_alias = true, exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true, - fallback = None))] + fallback = None, context = None))] pub fn to_python( &self, py: Python, @@ -110,6 +112,7 @@ impl SchemaSerializer { round_trip: bool, warnings: bool, fallback: Option<&PyAny>, + context: Option<&PyAny>, ) -> PyResult { let mode: SerMode = mode.into(); let warnings = CollectWarnings::new(warnings); @@ -126,6 +129,7 @@ 
impl SchemaSerializer { &rec_guard, false, fallback, + context, ); let v = self.serializer.to_python(value, include, exclude, &extra)?; warnings.final_check(py)?; @@ -135,7 +139,7 @@ impl SchemaSerializer { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true, exclude_unset = false, exclude_defaults = false, exclude_none = false, round_trip = false, warnings = true, - fallback = None))] + fallback = None, context = None))] pub fn to_json( &self, py: Python, @@ -150,6 +154,7 @@ impl SchemaSerializer { round_trip: bool, warnings: bool, fallback: Option<&PyAny>, + context: Option<&PyAny>, ) -> PyResult { let warnings = CollectWarnings::new(warnings); let rec_guard = SerRecursionState::default(); @@ -165,6 +170,7 @@ impl SchemaSerializer { &rec_guard, false, fallback, + context, ); let bytes = to_json_bytes( value, @@ -213,7 +219,7 @@ impl SchemaSerializer { #[pyfunction] #[pyo3(signature = (value, *, indent = None, include = None, exclude = None, by_alias = true, exclude_none = false, round_trip = false, timedelta_mode = "iso8601", bytes_mode = "utf8", - inf_nan_mode = "constants", serialize_unknown = false, fallback = None))] + inf_nan_mode = "constants", serialize_unknown = false, fallback = None, context = None))] pub fn to_json( py: Python, value: &PyAny, @@ -228,6 +234,7 @@ pub fn to_json( inf_nan_mode: &str, serialize_unknown: bool, fallback: Option<&PyAny>, + context: Option<&PyAny>, ) -> PyResult { let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?; let extra = state.extra( @@ -238,6 +245,7 @@ pub fn to_json( round_trip, serialize_unknown, fallback, + context, ); let serializer = type_serializers::any::AnySerializer.into(); let bytes = to_json_bytes(value, &serializer, include, exclude, &extra, indent, 1024)?; @@ -249,7 +257,7 @@ pub fn to_json( #[allow(clippy::too_many_arguments)] #[pyfunction] #[pyo3(signature = (value, *, include = None, exclude = None, 
by_alias = true, exclude_none = false, round_trip = false, - timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None))] + timedelta_mode = "iso8601", bytes_mode = "utf8", inf_nan_mode = "constants", serialize_unknown = false, fallback = None, context = None))] pub fn to_jsonable_python( py: Python, value: &PyAny, @@ -263,6 +271,7 @@ pub fn to_jsonable_python( inf_nan_mode: &str, serialize_unknown: bool, fallback: Option<&PyAny>, + context: Option<&PyAny>, ) -> PyResult { let state = SerializationState::new(timedelta_mode, bytes_mode, inf_nan_mode)?; let extra = state.extra( @@ -273,6 +282,7 @@ pub fn to_jsonable_python( round_trip, serialize_unknown, fallback, + context, ); let v = infer::infer_to_python(value, include, exclude, &extra)?; state.final_check(py)?; diff --git a/src/serializers/type_serializers/function.rs b/src/serializers/type_serializers/function.rs index cf4ef7710..170009383 100644 --- a/src/serializers/type_serializers/function.rs +++ b/src/serializers/type_serializers/function.rs @@ -1,9 +1,11 @@ use std::borrow::Cow; use pyo3::exceptions::{PyAttributeError, PyRecursionError, PyRuntimeError}; +use pyo3::gc::PyVisit; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; +use pyo3::PyTraverseError; use pyo3::types::PyString; @@ -440,6 +442,33 @@ impl SerializationCallable { exclude: exclude.map(|v| v.into_py(py)), } } + + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + if let Some(include) = &self.include { + visit.call(include)?; + } + if let Some(exclude) = &self.exclude { + visit.call(exclude)?; + } + if let Some(model) = &self.extra_owned.model { + visit.call(model)?; + } + if let Some(fallback) = &self.extra_owned.fallback { + visit.call(fallback)?; + } + if let Some(context) = &self.extra_owned.context { + visit.call(context)?; + } + Ok(()) + } + + fn __clear__(&mut self) { + self.include = None; + self.exclude = None; + 
self.extra_owned.model = None; + self.extra_owned.fallback = None; + self.extra_owned.context = None; + } } #[pymethods] @@ -488,6 +517,8 @@ struct SerializationInfo { include: Option, #[pyo3(get)] exclude: Option, + #[pyo3(get)] + context: Option, _mode: SerMode, #[pyo3(get)] by_alias: bool, @@ -515,6 +546,7 @@ impl SerializationInfo { Some(field_name) => Ok(Self { include: include.map(|i| i.into_py(py)), exclude: exclude.map(|e| e.into_py(py)), + context: extra.context.map(Into::into), _mode: extra.mode.clone(), by_alias: extra.by_alias, exclude_unset: extra.exclude_unset, @@ -531,6 +563,7 @@ impl SerializationInfo { Ok(Self { include: include.map(|i| i.into_py(py)), exclude: exclude.map(|e| e.into_py(py)), + context: extra.context.map(Into::into), _mode: extra.mode.clone(), by_alias: extra.by_alias, exclude_unset: extra.exclude_unset, @@ -541,6 +574,25 @@ impl SerializationInfo { }) } } + + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + if let Some(include) = &self.include { + visit.call(include)?; + } + if let Some(exclude) = &self.exclude { + visit.call(exclude)?; + } + if let Some(context) = &self.context { + visit.call(context)?; + } + Ok(()) + } + + fn __clear__(&mut self) { + self.include = None; + self.exclude = None; + self.context = None; + } } #[pymethods] @@ -563,6 +615,9 @@ impl SerializationInfo { if let Some(ref exclude) = self.exclude { d.set_item("exclude", exclude)?; } + if let Some(ref context) = self.context { + d.set_item("context", context)?; + } d.set_item("mode", self.mode(py))?; d.set_item("by_alias", self.by_alias)?; d.set_item("exclude_unset", self.exclude_unset)?; @@ -574,7 +629,7 @@ impl SerializationInfo { fn __repr__(&self, py: Python) -> PyResult { Ok(format!( - "SerializationInfo(include={}, exclude={}, mode='{}', by_alias={}, exclude_unset={}, exclude_defaults={}, exclude_none={}, round_trip={})", + "SerializationInfo(include={}, exclude={}, context={}, mode='{}', by_alias={}, exclude_unset={}, 
exclude_defaults={}, exclude_none={}, round_trip={})", match self.include { Some(ref include) => include.as_ref(py).repr()?.to_str()?, None => "None", @@ -583,6 +638,10 @@ impl SerializationInfo { Some(ref exclude) => exclude.as_ref(py).repr()?.to_str()?, None => "None", }, + match self.context { + Some(ref context) => context.as_ref(py).repr()?.to_str()?, + None => "None", + }, self._mode, py_bool(self.by_alias), py_bool(self.exclude_unset), diff --git a/src/serializers/type_serializers/generator.rs b/src/serializers/type_serializers/generator.rs index 48a01e939..4637bc44e 100644 --- a/src/serializers/type_serializers/generator.rs +++ b/src/serializers/type_serializers/generator.rs @@ -1,8 +1,10 @@ use std::borrow::Cow; +use pyo3::gc::PyVisit; use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyIterator}; +use pyo3::PyTraverseError; use serde::ser::SerializeSeq; @@ -173,6 +175,33 @@ impl SerializationIterator { exclude: exclude.map(|v| v.into_py(py)), } } + + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + if let Some(include) = &self.include { + visit.call(include)?; + } + if let Some(exclude) = &self.exclude { + visit.call(exclude)?; + } + if let Some(model) = &self.extra_owned.model { + visit.call(model)?; + } + if let Some(fallback) = &self.extra_owned.fallback { + visit.call(fallback)?; + } + if let Some(context) = &self.extra_owned.context { + visit.call(context)?; + } + Ok(()) + } + + fn __clear__(&mut self) { + self.include = None; + self.exclude = None; + self.extra_owned.model = None; + self.extra_owned.fallback = None; + self.extra_owned.context = None; + } } #[pymethods] diff --git a/src/validators/function.rs b/src/validators/function.rs index 34e0b6327..51195fbe4 100644 --- a/src/validators/function.rs +++ b/src/validators/function.rs @@ -480,6 +480,18 @@ impl ValidationInfo { mode: extra.input_type, } } + + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + 
visit.call(&self.config)?; + if let Some(context) = &self.context { + visit.call(context)?; + } + Ok(()) + } + + fn __clear__(&mut self) { + self.context = None; + } } #[pymethods] diff --git a/tests/serializers/test_functions.py b/tests/serializers/test_functions.py index 8851a7d36..e9a59bad2 100644 --- a/tests/serializers/test_functions.py +++ b/tests/serializers/test_functions.py @@ -122,6 +122,18 @@ def double(value, info): 'round_trip': False, } + assert s.to_python(1, context='context') == 2 + # insert_assert(f_info) + assert f_info == { + 'context': 'context', + 'mode': 'python', + 'by_alias': True, + 'exclude_unset': False, + 'exclude_defaults': False, + 'exclude_none': False, + 'round_trip': False, + } + def test_function_error(): def raise_error(value, _info): @@ -212,23 +224,27 @@ def append_args(value, info): ) ) assert s.to_python(123) == ( - "123 info=SerializationInfo(include=None, exclude=None, mode='python', by_alias=True, exclude_unset=False, " + "123 info=SerializationInfo(include=None, exclude=None, context=None, mode='python', by_alias=True, exclude_unset=False, " 'exclude_defaults=False, exclude_none=False, round_trip=False)' ) assert s.to_python(123, mode='other') == ( - "123 info=SerializationInfo(include=None, exclude=None, mode='other', by_alias=True, exclude_unset=False, " + "123 info=SerializationInfo(include=None, exclude=None, context=None, mode='other', by_alias=True, exclude_unset=False, " 'exclude_defaults=False, exclude_none=False, round_trip=False)' ) assert s.to_python(123, include={'x'}) == ( - "123 info=SerializationInfo(include={'x'}, exclude=None, mode='python', by_alias=True, exclude_unset=False, " + "123 info=SerializationInfo(include={'x'}, exclude=None, context=None, mode='python', by_alias=True, exclude_unset=False, " + 'exclude_defaults=False, exclude_none=False, round_trip=False)' + ) + assert s.to_python(123, context='context') == ( + "123 info=SerializationInfo(include=None, exclude=None, context='context', 
mode='python', by_alias=True, exclude_unset=False, " 'exclude_defaults=False, exclude_none=False, round_trip=False)' ) assert s.to_python(123, mode='json', exclude={1: {2}}) == ( - "123 info=SerializationInfo(include=None, exclude={1: {2}}, mode='json', by_alias=True, exclude_unset=False, " + "123 info=SerializationInfo(include=None, exclude={1: {2}}, context=None, mode='json', by_alias=True, exclude_unset=False, " 'exclude_defaults=False, exclude_none=False, round_trip=False)' ) assert s.to_json(123) == ( - b"\"123 info=SerializationInfo(include=None, exclude=None, mode='json', by_alias=True, exclude_unset=False, " + b"\"123 info=SerializationInfo(include=None, exclude=None, context=None, mode='json', by_alias=True, exclude_unset=False, " b'exclude_defaults=False, exclude_none=False, round_trip=False)"' ) diff --git a/tests/test.rs b/tests/test.rs index 4a5776e7d..597084b3b 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -79,7 +79,9 @@ a = A() let schema: &PyDict = locals.get_item("schema").unwrap().unwrap().extract().unwrap(); let serialized: Vec = SchemaSerializer::py_new(py, schema, None) .unwrap() - .to_json(py, a, None, None, None, true, false, false, false, false, true, None) + .to_json( + py, a, None, None, None, true, false, false, false, false, true, None, None, + ) .unwrap() .extract(py) .unwrap(); @@ -187,6 +189,7 @@ dump_json_input_2 = {'a': 'something'} false, false, None, + None, ) .unwrap(); let serialization_result: &PyAny = binding.extract(py).unwrap(); @@ -208,6 +211,7 @@ dump_json_input_2 = {'a': 'something'} false, false, None, + None, ) .unwrap(); let serialization_result: &PyAny = binding.extract(py).unwrap(); diff --git a/tests/test_typing.py b/tests/test_typing.py index 63d17a790..22ba1f6e3 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -232,7 +232,7 @@ def f(input: Any, info: core_schema.SerializationInfo, /) -> str: ) ) assert s.to_python(123) == ( - "SerializationInfo(include=None, exclude=None, mode='python', 
by_alias=True, exclude_unset=False, " + "SerializationInfo(include=None, exclude=None, context=None, mode='python', by_alias=True, exclude_unset=False, " 'exclude_defaults=False, exclude_none=False, round_trip=False)' ) @@ -253,7 +253,7 @@ def f( # insert_assert(s.to_python(123, mode='json')) assert s.to_python(123, mode='json') == ( 'SerializationCallable(serializer=str) ' - "SerializationInfo(include=None, exclude=None, mode='json', by_alias=True, exclude_unset=False, " + "SerializationInfo(include=None, exclude=None, context=None, mode='json', by_alias=True, exclude_unset=False, " 'exclude_defaults=False, exclude_none=False, round_trip=False)' ) From 4a533aa93da212e959572e28866afe551f811fa2 Mon Sep 17 00:00:00 2001 From: "Bernhard M. Wiedemann" Date: Tue, 12 Mar 2024 12:41:47 +0100 Subject: [PATCH 227/550] Make tests pass in 2032 (#1221) --- tests/benchmarks/test_micro_benchmarks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 6086c5bd7..80b3e6a47 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -635,13 +635,13 @@ def test_date_from_datetime_str(self, benchmark, validator): def test_core_future(self, benchmark): v = SchemaValidator({'type': 'date', 'gt': date.today()}) - benchmark(v.validate_python, date(2032, 1, 1)) + benchmark(v.validate_python, date(2932, 1, 1)) @pytest.mark.benchmark(group='date future') def test_core_future_str(self, benchmark): v = SchemaValidator({'type': 'date', 'gt': date.today()}) - benchmark(v.validate_python, str(date(2032, 1, 1))) + benchmark(v.validate_python, str(date(2932, 1, 1))) class TestBenchmarkUnion: From 71d54a243f52550847716cb4208944667ebc16ee Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 12 Mar 2024 19:57:39 +0000 Subject: [PATCH 228/550] update to PyO3 0.21 beta (#1222) --- Cargo.lock | 60 +---- Cargo.toml | 8 +- benches/main.rs | 204 
+++++++------- src/argument_markers.rs | 8 +- src/build_tools.rs | 28 +- src/errors/line_error.rs | 13 +- src/errors/location.rs | 10 +- src/errors/mod.rs | 4 +- src/errors/types.rs | 27 +- src/errors/validation_exception.rs | 45 ++-- src/errors/value_exception.rs | 18 +- src/input/datetime.rs | 91 ++++--- src/input/input_abstract.rs | 18 +- src/input/input_json.rs | 16 +- src/input/input_python.rs | 222 +++++++-------- src/input/input_string.rs | 28 +- src/input/return_enums.rs | 248 ++++++++--------- src/input/shared.rs | 16 +- src/lib.rs | 11 +- src/lookup_key.rs | 93 +++---- src/serializers/computed_fields.rs | 66 ++--- src/serializers/config.rs | 33 ++- src/serializers/extra.rs | 45 ++-- src/serializers/fields.rs | 138 +++++----- src/serializers/filter.rs | 136 +++++----- src/serializers/infer.rs | 252 +++++++++--------- src/serializers/mod.rs | 56 ++-- src/serializers/ob_type.rs | 119 +++++---- src/serializers/shared.rs | 97 +++---- src/serializers/type_serializers/any.rs | 18 +- src/serializers/type_serializers/bytes.rs | 18 +- src/serializers/type_serializers/dataclass.rs | 61 +++-- .../type_serializers/datetime_etc.rs | 30 +-- src/serializers/type_serializers/decimal.rs | 18 +- .../type_serializers/definitions.rs | 40 +-- src/serializers/type_serializers/dict.rs | 58 ++-- src/serializers/type_serializers/float.rs | 18 +- src/serializers/type_serializers/format.rs | 86 +++--- src/serializers/type_serializers/function.rs | 166 ++++++------ src/serializers/type_serializers/generator.rs | 67 +++-- src/serializers/type_serializers/json.rs | 22 +- .../type_serializers/json_or_python.rs | 26 +- src/serializers/type_serializers/list.rs | 38 ++- src/serializers/type_serializers/literal.rs | 42 +-- src/serializers/type_serializers/model.rs | 77 +++--- src/serializers/type_serializers/nullable.rs | 22 +- src/serializers/type_serializers/other.rs | 42 +-- .../type_serializers/set_frozenset.rs | 30 +-- src/serializers/type_serializers/simple.rs | 44 +-- 
src/serializers/type_serializers/string.rs | 28 +- src/serializers/type_serializers/timedelta.rs | 18 +- src/serializers/type_serializers/tuple.rs | 48 ++-- .../type_serializers/typed_dict.rs | 16 +- src/serializers/type_serializers/union.rs | 32 +-- src/serializers/type_serializers/url.rs | 18 +- src/serializers/type_serializers/uuid.rs | 20 +- .../type_serializers/with_default.rs | 22 +- src/tools.rs | 59 ++-- src/url.rs | 32 +-- src/validators/any.rs | 4 +- src/validators/arguments.rs | 59 ++-- src/validators/bool.rs | 4 +- src/validators/bytes.rs | 6 +- src/validators/call.rs | 30 ++- src/validators/callable.rs | 4 +- src/validators/chain.rs | 16 +- src/validators/custom_error.rs | 12 +- src/validators/dataclass.rs | 144 +++++----- src/validators/date.rs | 10 +- src/validators/datetime.rs | 36 ++- src/validators/decimal.rs | 81 +++--- src/validators/definitions.rs | 29 +- src/validators/dict.rs | 16 +- src/validators/float.rs | 12 +- src/validators/frozenset.rs | 4 +- src/validators/function.rs | 204 ++++++++------ src/validators/generator.rs | 30 ++- src/validators/int.rs | 6 +- src/validators/is_instance.rs | 15 +- src/validators/is_subclass.rs | 8 +- src/validators/json.rs | 8 +- src/validators/json_or_python.rs | 12 +- src/validators/lax_or_strict.rs | 8 +- src/validators/list.rs | 10 +- src/validators/literal.rs | 18 +- src/validators/mod.rs | 127 ++++----- src/validators/model.rs | 125 ++++----- src/validators/model_fields.rs | 71 ++--- src/validators/none.rs | 4 +- src/validators/nullable.rs | 8 +- src/validators/set.rs | 10 +- src/validators/string.rs | 32 ++- src/validators/time.rs | 12 +- src/validators/timedelta.rs | 12 +- src/validators/tuple.rs | 28 +- src/validators/typed_dict.rs | 40 +-- src/validators/union.rs | 100 +++---- src/validators/url.rs | 10 +- src/validators/uuid.rs | 28 +- src/validators/validation_state.rs | 9 +- src/validators/with_default.rs | 26 +- tests/test.rs | 72 +++-- 102 files changed, 2520 insertions(+), 2306 deletions(-) 
diff --git a/Cargo.lock b/Cargo.lock index b74074c8d..667d15b1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -148,32 +148,19 @@ checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" [[package]] name = "jiter" -version = "0.0.6" +version = "0.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87db066a99f69382be06d02313f8ce989996b53a04a8a70cfd1a6483a56227f7" +checksum = "c2a1b6e316923afd3087ec73829f646a67c18f3a5bd61624247b05e652e4a99d" dependencies = [ "ahash", "hashbrown", - "lexical-core", + "lexical-parse-float", "num-bigint", "num-traits", "pyo3", "smallvec", ] -[[package]] -name = "lexical-core" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" -dependencies = [ - "lexical-parse-float", - "lexical-parse-integer", - "lexical-util", - "lexical-write-float", - "lexical-write-integer", -] - [[package]] name = "lexical-parse-float" version = "0.8.5" @@ -204,27 +191,6 @@ dependencies = [ "static_assertions", ] -[[package]] -name = "lexical-write-float" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" -dependencies = [ - "lexical-util", - "lexical-write-integer", - "static_assertions", -] - -[[package]] -name = "lexical-write-integer" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" -dependencies = [ - "lexical-util", - "static_assertions", -] - [[package]] name = "libc" version = "0.2.147" @@ -363,9 +329,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.3" +version = "0.21.0-beta.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" +checksum = 
"5d0c41d899f822e5f39186d6da130a822a0a43edb19992b51bf4ef6cd0b4cfd1" dependencies = [ "cfg-if", "indoc", @@ -382,9 +348,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.3" +version = "0.21.0-beta.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" +checksum = "5509c2aa78c7e770077e41ba86f806e60dcee812e924ccb2d6fe78c0a0128ce2" dependencies = [ "once_cell", "python3-dll-a", @@ -393,9 +359,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.3" +version = "0.21.0-beta.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" +checksum = "e6bb234a86ed619a661f3bb3c2493aaff9cb937e33e198d17f5f20a15881e155" dependencies = [ "libc", "pyo3-build-config", @@ -403,9 +369,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.3" +version = "0.21.0-beta.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" +checksum = "f0b787de2c6832eb1eb393c9f82f976a5a87bda979780d9b853878846a8d2e4b" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -415,9 +381,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.3" +version = "0.21.0-beta.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" +checksum = "5e3b7beed357786d2afe845871964e824ad8af0df38a403f7d01cdc81aadb211" dependencies = [ "heck", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index ea3e3d73c..f74976967 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ include = [ rust-version = "1.70" [dependencies] -pyo3 = { version = "0.20.3", features = ["generate-import-lib", "num-bigint"] } +pyo3 = { version = "0.21.0-beta.0", features = ["generate-import-lib", 
"num-bigint"] } regex = "1.10.3" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.26.1" @@ -44,7 +44,7 @@ base64 = "0.21.7" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.7.0" -jiter = {version = "0.0.6", features = ["python"]} +jiter = {version = "0.0.7", features = ["python"]} [lib] name = "_pydantic_core" @@ -71,12 +71,12 @@ debug = true strip = false [dev-dependencies] -pyo3 = { version = "0.20.3", features = ["auto-initialize"] } +pyo3 = { version = "0.21.0-beta.0", features = ["auto-initialize"] } [build-dependencies] version_check = "0.9.4" # used where logic has to be version/distribution specific, e.g. pypy -pyo3-build-config = { version = "0.20.2" } +pyo3-build-config = { version = "0.21.0-beta.0" } [lints.clippy] dbg_macro = "warn" diff --git a/benches/main.rs b/benches/main.rs index 4b8a2b106..d1ef27a93 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -10,17 +10,17 @@ use pyo3::types::{PyDict, PyString}; use _pydantic_core::{validate_core_schema, SchemaValidator}; fn build_schema_validator_with_globals(py: Python, code: &str, globals: Option<&PyDict>) -> SchemaValidator { - let mut schema: &PyDict = py.eval(code, globals, None).unwrap().extract().unwrap(); - schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); - SchemaValidator::py_new(py, schema, None).unwrap() + let mut schema = py.eval(code, globals, None).unwrap().extract().unwrap(); + schema = validate_core_schema(&schema, None).unwrap().extract().unwrap(); + SchemaValidator::py_new(py, &schema, None).unwrap() } fn build_schema_validator(py: Python, code: &str) -> SchemaValidator { build_schema_validator_with_globals(py, code, None) } -fn json<'a>(py: Python<'a>, code: &'a str) -> &'a PyAny { - black_box(PyString::new(py, code)) +fn json<'a>(py: Python<'a>, code: &'a str) -> Bound<'a, PyAny> { + black_box(PyString::new_bound(py, code).into_any()) } #[bench] @@ -28,11 +28,11 @@ fn ints_json(bench: &mut Bencher) { Python::with_gil(|py| { let 
validator = build_schema_validator(py, "{'type': 'int'}"); - let result = validator.validate_json(py, json(py, "123"), None, None, None).unwrap(); + let result = validator.validate_json(py, &json(py, "123"), None, None, None).unwrap(); let result_int: i64 = result.extract(py).unwrap(); assert_eq!(result_int, 123); - bench.iter(|| black_box(validator.validate_json(py, json(py, "123"), None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_json(py, &json(py, "123"), None, None, None).unwrap())) }) } @@ -41,14 +41,13 @@ fn ints_python(bench: &mut Bencher) { Python::with_gil(|py| { let validator = build_schema_validator(py, "{'type': 'int'}"); - let input = 123_i64.into_py(py); - let input = input.as_ref(py); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let input = 123_i64.into_py(py).into_bound(py); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let result_int: i64 = result.extract(py).unwrap(); assert_eq!(result_int, 123); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) }) } @@ -61,7 +60,7 @@ fn list_int_json(bench: &mut Bencher) { (0..100).map(|x| x.to_string()).collect::>().join(",") ); - bench.iter(|| black_box(validator.validate_json(py, json(py, &code), None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_json(py, &json(py, &code), None, None, None).unwrap())) }) } @@ -80,9 +79,9 @@ fn list_int_input(py: Python<'_>) -> (SchemaValidator, PyObject) { fn list_int_python(bench: &mut Bencher) { Python::with_gil(|py| { let (validator, input) = list_int_input(py); - let input = black_box(input.as_ref(py)); + let input = black_box(input.bind(py)); bench.iter(|| { - let v = validator.validate_python(py, input, None, None, None, None).unwrap(); + let v = 
validator.validate_python(py, &input, None, None, None, None).unwrap(); black_box(v) }) }) @@ -92,12 +91,12 @@ fn list_int_python(bench: &mut Bencher) { fn list_int_python_isinstance(bench: &mut Bencher) { Python::with_gil(|py| { let (validator, input) = list_int_input(py); - let input = black_box(input.as_ref(py)); - let v = validator.isinstance_python(py, input, None, None, None, None).unwrap(); + let input = black_box(input.bind(py)); + let v = validator.isinstance_python(py, &input, None, None, None, None).unwrap(); assert!(v); bench.iter(|| { - let v = validator.isinstance_python(py, input, None, None, None, None).unwrap(); + let v = validator.isinstance_python(py, &input, None, None, None, None).unwrap(); black_box(v) }) }) @@ -115,7 +114,7 @@ fn list_error_json(bench: &mut Bencher) { .join(", ") ); - match validator.validate_json(py, json(py, &code), None, None, None) { + match validator.validate_json(py, &json(py, &code), None, None, None) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); @@ -127,7 +126,7 @@ fn list_error_json(bench: &mut Bencher) { }; bench.iter( - || match validator.validate_json(py, json(py, &code), None, None, None) { + || match validator.validate_json(py, &json(py, &code), None, None, None) { Ok(_) => panic!("unexpectedly valid"), Err(e) => black_box(e), }, @@ -145,9 +144,9 @@ fn list_error_python_input(py: Python<'_>) -> (SchemaValidator, PyObject) { .join(", ") ); - let input = py.eval(&code, None, None).unwrap(); + let input = py.eval(&code, None, None).unwrap().extract().unwrap(); - match validator.validate_python(py, input, None, None, None, None) { + match validator.validate_python(py, &input, None, None, None, None) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); @@ -165,9 +164,9 @@ fn list_error_python(bench: &mut Bencher) { Python::with_gil(|py| { let (validator, input) = list_error_python_input(py); - let input = black_box(input.as_ref(py)); + let input = 
black_box(input.bind(py)); bench.iter(|| { - let result = validator.validate_python(py, input, None, None, None, None); + let result = validator.validate_python(py, &input, None, None, None, None); match result { Ok(_) => panic!("unexpectedly valid"), @@ -181,14 +180,12 @@ fn list_error_python(bench: &mut Bencher) { fn list_error_python_isinstance(bench: &mut Bencher) { Python::with_gil(|py| { let (validator, input) = list_error_python_input(py); - let r = validator - .isinstance_python(py, black_box(input.as_ref(py)), None, None, None, None) - .unwrap(); + let input = black_box(input.bind(py)); + let r = validator.isinstance_python(py, &input, None, None, None, None).unwrap(); assert!(!r); - let input = black_box(input.as_ref(py)); bench.iter(|| { - black_box(validator.isinstance_python(py, input, None, None, None, None).unwrap()); + black_box(validator.isinstance_python(py, &input, None, None, None, None).unwrap()); }) }) } @@ -202,7 +199,7 @@ fn list_any_json(bench: &mut Bencher) { (0..100).map(|x| x.to_string()).collect::>().join(",") ); - bench.iter(|| black_box(validator.validate_json(py, json(py, &code), None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_json(py, &json(py, &code), None, None, None).unwrap())) }) } @@ -214,10 +211,10 @@ fn list_any_python(bench: &mut Bencher) { "[{}]", (0..100).map(|x| x.to_string()).collect::>().join(",") ); - let input = py.eval(&code, None, None).unwrap(); - let input = black_box(input); + let input = py.eval(&code, None, None).unwrap().to_object(py); + let input = black_box(input.bind(py)); bench.iter(|| { - let v = validator.validate_python(py, input, None, None, None, None).unwrap(); + let v = validator.validate_python(py, &input, None, None, None, None).unwrap(); black_box(v) }) }) @@ -247,7 +244,7 @@ fn dict_json(bench: &mut Bencher) { .join(", ") ); - bench.iter(|| black_box(validator.validate_json(py, json(py, &code), None, None, None).unwrap())) + bench.iter(|| 
black_box(validator.validate_json(py, &json(py, &code), None, None, None).unwrap())) }) } @@ -266,10 +263,10 @@ fn dict_python(bench: &mut Bencher) { .collect::>() .join(", ") ); - let input = py.eval(&code, None, None).unwrap(); - let input = black_box(input); + let input = py.eval(&code, None, None).unwrap().to_object(py); + let input = black_box(input.bind(py)); bench.iter(|| { - let v = validator.validate_python(py, input, None, None, None, None).unwrap(); + let v = validator.validate_python(py, &input, None, None, None, None).unwrap(); black_box(v) }) }) @@ -295,9 +292,9 @@ fn dict_value_error(bench: &mut Bencher) { .join(", ") ); - let input = py.eval(&code, None, None).unwrap(); + let input = py.eval(&code, None, None).unwrap().to_object(py).into_bound(py); - match validator.validate_python(py, input, None, None, None, None) { + match validator.validate_python(py, &input, None, None, None, None) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); @@ -310,7 +307,7 @@ fn dict_value_error(bench: &mut Bencher) { let input = black_box(input); bench.iter(|| { - let result = validator.validate_python(py, input, None, None, None, None); + let result = validator.validate_python(py, &input, None, None, None, None); match result { Ok(_) => panic!("unexpectedly valid"), @@ -345,7 +342,7 @@ fn typed_dict_json(bench: &mut Bencher) { let code = r#"{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 0}"#.to_string(); - bench.iter(|| black_box(validator.validate_json(py, json(py, &code), None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_json(py, &json(py, &code), None, None, None).unwrap())) }) } @@ -373,10 +370,10 @@ fn typed_dict_python(bench: &mut Bencher) { ); let code = r#"{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 0}"#.to_string(); - let input = py.eval(&code, None, None).unwrap(); - let input = black_box(input); + let input = py.eval(&code, None, 
None).unwrap().to_object(py); + let input = black_box(input.bind(py)); bench.iter(|| { - let v = validator.validate_python(py, input, None, None, None, None).unwrap(); + let v = validator.validate_python(py, &input, None, None, None, None).unwrap(); black_box(v) }) }) @@ -413,10 +410,10 @@ fn typed_dict_deep_error(bench: &mut Bencher) { let code = "{'field_a': '1', 'field_b': {'field_c': '2', 'field_d': {'field_e': '4', 'field_f': 'xx'}}}"; - let input = py.eval(code, None, None).unwrap(); - let input = black_box(input); + let input = py.eval(code, None, None).unwrap().to_object(py); + let input = black_box(input.bind(py)); - match validator.validate_python(py, input, None, None, None, None) { + match validator.validate_python(py, &input, None, None, None, None) { Ok(_) => panic!("unexpectedly valid"), Err(e) => { let v = e.value(py); @@ -428,7 +425,7 @@ fn typed_dict_deep_error(bench: &mut Bencher) { }; bench.iter(|| { - let result = validator.validate_python(py, input, None, None, None, None); + let result = validator.validate_python(py, &input, None, None, None, None); match result { Ok(_) => panic!("unexpectedly valid"), @@ -444,16 +441,16 @@ fn complete_model(bench: &mut Bencher) { let sys_path = py.import("sys").unwrap().getattr("path").unwrap(); sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap(); - let complete_schema = py.import("complete_schema").unwrap(); + let complete_schema = py.import_bound("complete_schema").unwrap(); let mut schema = complete_schema.call_method0("schema").unwrap(); - schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); - let validator = SchemaValidator::py_new(py, schema, None).unwrap(); + schema = validate_core_schema(&schema, None).unwrap().extract().unwrap(); + let validator = SchemaValidator::py_new(py, &schema, None).unwrap(); let input = complete_schema.call_method0("input_data_lax").unwrap(); let input = black_box(input); bench.iter(|| { - black_box(validator.validate_python(py, input, 
None, None, None, None).unwrap()); + black_box(validator.validate_python(py, &input, None, None, None, None).unwrap()); }) }) } @@ -464,18 +461,18 @@ fn nested_model_using_definitions(bench: &mut Bencher) { let sys_path = py.import("sys").unwrap().getattr("path").unwrap(); sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap(); - let complete_schema = py.import("nested_schema").unwrap(); + let complete_schema = py.import_bound("nested_schema").unwrap(); let mut schema = complete_schema.call_method0("schema_using_defs").unwrap(); - schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); - let validator = SchemaValidator::py_new(py, schema, None).unwrap(); + schema = validate_core_schema(&schema, None).unwrap().extract().unwrap(); + let validator = SchemaValidator::py_new(py, &schema, None).unwrap(); let input = complete_schema.call_method0("input_data_valid").unwrap(); let input = black_box(input); - validator.validate_python(py, input, None, None, None, None).unwrap(); + validator.validate_python(py, &input, None, None, None, None).unwrap(); bench.iter(|| { - black_box(validator.validate_python(py, input, None, None, None, None).unwrap()); + black_box(validator.validate_python(py, &input, None, None, None, None).unwrap()); }) }) } @@ -486,18 +483,18 @@ fn nested_model_inlined(bench: &mut Bencher) { let sys_path = py.import("sys").unwrap().getattr("path").unwrap(); sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap(); - let complete_schema = py.import("nested_schema").unwrap(); + let complete_schema = py.import_bound("nested_schema").unwrap(); let mut schema = complete_schema.call_method0("inlined_schema").unwrap(); - schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap(); - let validator = SchemaValidator::py_new(py, schema, None).unwrap(); + schema = validate_core_schema(&schema, None).unwrap().extract().unwrap(); + let validator = SchemaValidator::py_new(py, &schema, None).unwrap(); let input = 
complete_schema.call_method0("input_data_valid").unwrap(); let input = black_box(input); - validator.validate_python(py, input, None, None, None, None).unwrap(); + validator.validate_python(py, &input, None, None, None, None).unwrap(); bench.iter(|| { - black_box(validator.validate_python(py, input, None, None, None, None).unwrap()); + black_box(validator.validate_python(py, &input, None, None, None, None).unwrap()); }) }) } @@ -508,13 +505,13 @@ fn literal_ints_few_python(bench: &mut Bencher) { let validator = build_schema_validator(py, "{'type': 'literal', 'expected': list(range(5))}"); let input = 4_i64.into_py(py); - let input = input.as_ref(py); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let input = input.bind(py); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let result_int: i64 = result.extract(py).unwrap(); assert_eq!(result_int, 4); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) }) } @@ -524,13 +521,14 @@ fn literal_strings_few_small_python(bench: &mut Bencher) { let validator = build_schema_validator(py, "{'type': 'literal', 'expected': [f'{idx}' for idx in range(5)]}"); let input = py.eval("'4'", None, None).unwrap(); + let input = input.to_object(py).into_bound(py); let input_str: String = input.extract().unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let result_str: String = result.extract(py).unwrap(); assert_eq!(result_str, input_str); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, 
None).unwrap())) }) } @@ -543,20 +541,21 @@ fn literal_strings_few_large_python(bench: &mut Bencher) { ); let input = py.eval("'a' * 25 + '4'", None, None).unwrap(); + let input = input.to_object(py).into_bound(py); let input_str: String = input.extract().unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let result_str: String = result.extract(py).unwrap(); assert_eq!(result_str, input_str); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) }) } #[bench] fn literal_enums_few_python(bench: &mut Bencher) { Python::with_gil(|py| { - let globals = PyDict::new(py); + let globals = PyDict::new_bound(py); py.run( r#" from enum import Enum @@ -567,7 +566,7 @@ class Foo(Enum): v3 = object() v4 = object() "#, - Some(globals), + Some(globals.as_gil_ref()), None, ) .unwrap(); @@ -575,15 +574,16 @@ class Foo(Enum): let validator = build_schema_validator_with_globals( py, "{'type': 'literal', 'expected': [Foo.v1, Foo.v2, Foo.v3, Foo.v4]}", - Some(globals), + Some(globals.as_gil_ref()), ); - let input = py.eval("Foo.v4", Some(globals), None).unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let input = py.eval("Foo.v4", Some(globals.as_gil_ref()), None).unwrap(); + let input = input.to_object(py).into_bound(py); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); assert!(input.eq(result).unwrap()); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) }) } @@ -592,14 +592,13 @@ fn literal_ints_many_python(bench: &mut 
Bencher) { Python::with_gil(|py| { let validator = build_schema_validator(py, "{'type': 'literal', 'expected': list(range(100))}"); - let input = 99_i64.into_py(py); - let input = input.as_ref(py); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let input = 99_i64.into_py(py).into_bound(py); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let result_int: i64 = result.extract(py).unwrap(); assert_eq!(result_int, 99); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) }) } @@ -609,13 +608,14 @@ fn literal_strings_many_small_python(bench: &mut Bencher) { let validator = build_schema_validator(py, "{'type': 'literal', 'expected': [f'{idx}' for idx in range(100)]}"); let input = py.eval("'99'", None, None).unwrap(); + let input = input.to_object(py).into_bound(py); let input_str: String = input.extract().unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let result_str: String = result.extract(py).unwrap(); assert_eq!(result_str, input_str); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) }) } @@ -628,13 +628,14 @@ fn literal_strings_many_large_python(bench: &mut Bencher) { ); let input = py.eval("'a' * 25 + '99'", None, None).unwrap(); + let input = input.to_object(py).into_bound(py); let input_str: String = input.extract().unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let 
result_str: String = result.extract(py).unwrap(); assert_eq!(result_str, input_str); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) }) } @@ -644,12 +645,13 @@ fn literal_ints_many_json(bench: &mut Bencher) { let validator = build_schema_validator(py, "{'type': 'literal', 'expected': list(range(100))}"); let input_json = py.eval("'99'", None, None).unwrap(); - let result = validator.validate_json(py, input_json, None, None, None).unwrap(); + let input_json = input_json.to_object(py).into_bound(py); + let result = validator.validate_json(py, &input_json, None, None, None).unwrap(); let result_int: i64 = result.extract(py).unwrap(); assert_eq!(result_int, 99); let input_json = black_box(input_json); - bench.iter(|| black_box(validator.validate_json(py, input_json, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_json(py, &input_json, None, None, None).unwrap())) }) } @@ -662,21 +664,23 @@ fn literal_strings_many_large_json(bench: &mut Bencher) { ); let input = py.eval("'a' * 25 + '99'", None, None).unwrap(); + let input = input.to_object(py).into_bound(py); let input_json = py.eval("'\"' + 'a' * 25 + '99' + '\"'", None, None).unwrap(); + let input_json = input_json.to_object(py).into_bound(py); let input_str: String = input.extract().unwrap(); - let result = validator.validate_json(py, input_json, None, None, None).unwrap(); + let result = validator.validate_json(py, &input_json, None, None, None).unwrap(); let result_str: String = result.extract(py).unwrap(); assert_eq!(result_str, input_str); let input_json = black_box(input_json); - bench.iter(|| black_box(validator.validate_json(py, input_json, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_json(py, &input_json, None, None, None).unwrap())) }) } #[bench] fn 
literal_mixed_few_python(bench: &mut Bencher) { Python::with_gil(|py| { - let globals = PyDict::new(py); + let globals = PyDict::new_bound(py); py.run( r#" from enum import Enum @@ -687,58 +691,62 @@ class Foo(Enum): v3 = object() v4 = object() "#, - Some(globals), + Some(globals.as_gil_ref()), None, ) .unwrap(); let validator = build_schema_validator_with_globals( py, "{'type': 'literal', 'expected': [None, 'null', -1, Foo.v4]}", - Some(globals), + Some(globals.as_gil_ref()), ); // String { let input = py.eval("'null'", None, None).unwrap(); + let input = input.to_object(py).into_bound(py); let input_str: String = input.extract().unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let result_str: String = result.extract(py).unwrap(); assert_eq!(result_str, input_str); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) } // Int { let input = py.eval("-1", None, None).unwrap(); + let input = input.to_object(py).into_bound(py); let input_int: i64 = input.extract().unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); let result_int: i64 = result.extract(py).unwrap(); assert_eq!(result_int, input_int); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) } // None { let input = py.eval("None", None, None).unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let input = input.to_object(py).into_bound(py); + let result = 
validator.validate_python(py, &input, None, None, None, None).unwrap(); assert!(input.eq(result).unwrap()); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) } // Enum { - let input = py.eval("Foo.v4", Some(globals), None).unwrap(); - let result = validator.validate_python(py, input, None, None, None, None).unwrap(); + let input = py.eval("Foo.v4", Some(globals.as_gil_ref()), None).unwrap(); + let input = input.to_object(py).into_bound(py); + let result = validator.validate_python(py, &input, None, None, None, None).unwrap(); assert!(input.eq(result).unwrap()); let input = black_box(input); - bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap())) + bench.iter(|| black_box(validator.validate_python(py, &input, None, None, None, None).unwrap())) } }) } diff --git a/src/argument_markers.rs b/src/argument_markers.rs index 3e73fb629..748dd0ae2 100644 --- a/src/argument_markers.rs +++ b/src/argument_markers.rs @@ -15,9 +15,9 @@ pub struct ArgsKwargs { impl ArgsKwargs { fn eq(&self, py: Python, other: &Self) -> PyResult { - if self.args.as_ref(py).eq(other.args.as_ref(py))? { + if self.args.bind(py).eq(other.args.bind(py))? 
{ match (&self.kwargs, &other.kwargs) { - (Some(d1), Some(d2)) => d1.as_ref(py).eq(d2.as_ref(py)), + (Some(d1), Some(d2)) => d1.bind(py).eq(d2.bind(py)), (None, None) => Ok(true), _ => Ok(false), } @@ -55,9 +55,9 @@ impl ArgsKwargs { } pub fn __repr__(&self, py: Python) -> String { - let args = safe_repr(self.args.as_ref(py)); + let args = safe_repr(self.args.bind(py)); match self.kwargs { - Some(ref d) => format!("ArgsKwargs({args}, {})", safe_repr(d.as_ref(py))), + Some(ref d) => format!("ArgsKwargs({args}, {})", safe_repr(d.bind(py))), None => format!("ArgsKwargs({args})"), } } diff --git a/src/build_tools.rs b/src/build_tools.rs index c242f97a3..23c88970b 100644 --- a/src/build_tools.rs +++ b/src/build_tools.rs @@ -12,10 +12,10 @@ use crate::tools::SchemaDict; use crate::ValidationError; pub fn schema_or_config<'py, T>( - schema: &'py PyDict, - config: Option<&'py PyDict>, - schema_key: &PyString, - config_key: &PyString, + schema: &Bound<'py, PyDict>, + config: Option<&Bound<'py, PyDict>>, + schema_key: &Bound<'py, PyString>, + config_key: &Bound<'py, PyString>, ) -> PyResult> where T: FromPyObject<'py>, @@ -30,9 +30,9 @@ where } pub fn schema_or_config_same<'py, T>( - schema: &'py PyDict, - config: Option<&'py PyDict>, - key: &PyString, + schema: &Bound<'py, PyDict>, + config: Option<&Bound<'py, PyDict>>, + key: &Bound<'py, PyString>, ) -> PyResult> where T: FromPyObject<'py>, @@ -40,7 +40,7 @@ where schema_or_config(schema, config, key, key) } -pub fn is_strict(schema: &PyDict, config: Option<&PyDict>) -> PyResult { +pub fn is_strict(schema: &Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>) -> PyResult { let py = schema.py(); Ok(schema_or_config_same(schema, config, intern!(py, "strict"))?.unwrap_or(false)) } @@ -90,7 +90,7 @@ impl SchemaError { ValidationError::new(line_errors, "Schema".to_object(py), InputType::Python, false); let schema_error = SchemaError(SchemaErrorEnum::ValidationError(validation_error)); match Py::new(py, schema_error) { - 
Ok(err) => PyErr::from_value(err.into_ref(py)), + Ok(err) => PyErr::from_value_bound(err.into_bound(py).into_any()), Err(err) => err, } } @@ -124,7 +124,7 @@ impl SchemaError { fn errors(&self, py: Python) -> PyResult> { match &self.0 { - SchemaErrorEnum::Message(_) => Ok(PyList::empty(py).into_py(py)), + SchemaErrorEnum::Message(_) => Ok(PyList::empty_bound(py).unbind()), SchemaErrorEnum::ValidationError(error) => error.errors(py, false, false, true), } } @@ -174,18 +174,18 @@ pub(crate) enum ExtraBehavior { impl ExtraBehavior { pub fn from_schema_or_config( py: Python, - schema: &PyDict, - config: Option<&PyDict>, + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, default: Self, ) -> PyResult { - let extra_behavior = schema_or_config::>( + let extra_behavior = schema_or_config::>>( schema, config, intern!(py, "extra_behavior"), intern!(py, "extra_fields_behavior"), )? .flatten(); - let res = match extra_behavior { + let res = match extra_behavior.as_ref().map(|s| s.to_str()).transpose()? 
{ Some("allow") => Self::Allow, Some("ignore") => Self::Ignore, Some("forbid") => Self::Forbid, diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index 2a48bfaf4..ecd429409 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -1,6 +1,7 @@ use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::PyDowncastError; +use pyo3::DowncastError; +use pyo3::DowncastIntoError; use jiter::JsonValue; @@ -35,8 +36,14 @@ impl From for ValError { } } -impl From> for ValError { - fn from(py_downcast: PyDowncastError) -> Self { +impl From> for ValError { + fn from(py_downcast: DowncastError) -> Self { + Self::InternalErr(PyTypeError::new_err(py_downcast.to_string())) + } +} + +impl From> for ValError { + fn from(py_downcast: DowncastIntoError) -> Self { Self::InternalErr(PyTypeError::new_err(py_downcast.to_string())) } } diff --git a/src/errors/location.rs b/src/errors/location.rs index 55bab0017..138d327ce 100644 --- a/src/errors/location.rs +++ b/src/errors/location.rs @@ -1,5 +1,5 @@ use pyo3::exceptions::PyTypeError; -use pyo3::once_cell::GILOnceCell; +use pyo3::sync::GILOnceCell; use std::fmt; use pyo3::prelude::*; @@ -126,9 +126,9 @@ static EMPTY_TUPLE: GILOnceCell = GILOnceCell::new(); impl ToPyObject for Location { fn to_object(&self, py: Python<'_>) -> PyObject { match self { - Self::List(loc) => PyTuple::new(py, loc.iter().rev()).to_object(py), + Self::List(loc) => PyTuple::new_bound(py, loc.iter().rev()).to_object(py), Self::Empty => EMPTY_TUPLE - .get_or_init(py, || PyTuple::empty(py).to_object(py)) + .get_or_init(py, || PyTuple::empty_bound(py).to_object(py)) .clone_ref(py), } } @@ -193,12 +193,12 @@ impl Serialize for Location { } } -impl TryFrom> for Location { +impl TryFrom>> for Location { type Error = PyErr; /// Only ever called by ValidationError -> PyLineError to convert user input to our internal Location /// Thus this expects the location to *not* be reversed and reverses it before storing it. 
- fn try_from(location: Option<&PyAny>) -> PyResult { + fn try_from(location: Option<&Bound<'_, PyAny>>) -> PyResult { if let Some(location) = location { let mut loc_vec: Vec = if let Ok(tuple) = location.downcast::() { tuple.iter().map(Into::into).collect() diff --git a/src/errors/mod.rs b/src/errors/mod.rs index de6650527..bee5f7225 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -13,8 +13,8 @@ pub use self::validation_exception::ValidationError; pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault}; pub fn py_err_string(py: Python, err: PyErr) -> String { - let value = err.value(py); - match value.get_type().name() { + let value = err.value_bound(py); + match value.get_type().qualname() { Ok(type_name) => match value.str() { Ok(py_str) => { let str_cow = py_str.to_string_lossy(); diff --git a/src/errors/types.rs b/src/errors/types.rs index eddd7dbaa..a77ed6a8e 100644 --- a/src/errors/types.rs +++ b/src/errors/types.rs @@ -3,8 +3,8 @@ use std::borrow::Cow; use std::fmt; use pyo3::exceptions::{PyKeyError, PyTypeError}; -use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; use pyo3::types::{PyDict, PyList}; use ahash::AHashMap; @@ -18,11 +18,11 @@ use crate::tools::{extract_i64, py_err, py_error_type}; use super::PydanticCustomError; #[pyfunction] -pub fn list_all_errors(py: Python) -> PyResult<&PyList> { - let mut errors: Vec<&PyDict> = Vec::with_capacity(100); +pub fn list_all_errors(py: Python) -> PyResult> { + let mut errors: Vec> = Vec::with_capacity(100); for error_type in ErrorType::iter() { if !matches!(error_type, ErrorType::CustomError { .. 
}) { - let d = PyDict::new(py); + let d = PyDict::new_bound(py); d.set_item("type", error_type.to_string())?; let message_template_python = error_type.message_template_python(); d.set_item("message_template_python", message_template_python)?; @@ -39,7 +39,7 @@ pub fn list_all_errors(py: Python) -> PyResult<&PyList> { errors.push(d); } } - Ok(PyList::new(py, errors)) + Ok(PyList::new_bound(py, errors)) } fn field_from_context<'py, T: FromPyObject<'py>>( @@ -121,7 +121,8 @@ macro_rules! error_types { } } - fn py_dict_update_ctx(&self, py: Python, dict: &PyDict) -> PyResult { + fn py_dict_update_ctx(&self, py: Python, dict: &Bound<'_, PyDict>) -> PyResult { + use pyo3::types::PyMapping; match self { $( Self::$item { context, $($key,)* } => { @@ -129,7 +130,7 @@ macro_rules! error_types { dict.set_item::<&str, Py>(stringify!($key), $key.to_object(py))?; )* if let Some(ctx) = context { - dict.update(ctx.as_ref(py).downcast()?)?; + dict.update(ctx.bind(py).downcast::()?)?; Ok(true) } else { Ok(false) @@ -669,20 +670,20 @@ impl ErrorType { Self::ValueError { error, .. } => { let error = &error .as_ref() - .map_or(Cow::Borrowed("None"), |v| Cow::Owned(v.as_ref(py).to_string())); + .map_or(Cow::Borrowed("None"), |v| Cow::Owned(v.bind(py).to_string())); render!(tmpl, error) } Self::AssertionError { error, .. } => { let error = &error .as_ref() - .map_or(Cow::Borrowed("None"), |v| Cow::Owned(v.as_ref(py).to_string())); + .map_or(Cow::Borrowed("None"), |v| Cow::Owned(v.bind(py).to_string())); render!(tmpl, error) } Self::CustomError { message_template, context, .. - } => PydanticCustomError::format_message(message_template, context.as_ref().map(|c| c.as_ref(py))), + } => PydanticCustomError::format_message(message_template, context.as_ref().map(|c| c.bind(py))), Self::LiteralError { expected, .. } => render!(tmpl, expected), Self::DateParsing { error, .. } => render!(tmpl, error), Self::DateFromDatetimeParsing { error, .. 
} => render!(tmpl, error), @@ -729,8 +730,8 @@ impl ErrorType { } pub fn py_dict(&self, py: Python) -> PyResult>> { - let dict = PyDict::new(py); - let custom_ctx_used = self.py_dict_update_ctx(py, dict)?; + let dict = PyDict::new_bound(py); + let custom_ctx_used = self.py_dict_update_ctx(py, &dict)?; if let Self::CustomError { .. } = self { if custom_ctx_used { @@ -785,7 +786,7 @@ impl From for Number { } impl FromPyObject<'_> for Number { - fn extract(obj: &PyAny) -> PyResult { + fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult { if let Some(int) = extract_i64(obj) { Ok(Number::Int(int)) } else if let Ok(float) = obj.extract::() { diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index e87f4b0d9..059763a21 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -5,8 +5,8 @@ use std::str::from_utf8; use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError}; use pyo3::ffi; use pyo3::intern; -use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; use pyo3::types::{PyDict, PyList, PyString}; use serde::ser::{Error, SerializeMap, SerializeSeq}; use serde::{Serialize, Serializer}; @@ -73,7 +73,7 @@ impl ValidationError { return cause_problem; } } - PyErr::from_value(err.as_ref(py)) + PyErr::from_value_bound(err.into_bound(py).into_any()) } Err(err) => err, } @@ -202,9 +202,9 @@ fn include_url_env(py: Python) -> bool { match std::env::var_os("PYDANTIC_ERRORS_OMIT_URL") { Some(val) => { // We don't care whether warning succeeded or not, hence the assignment - let _ = PyErr::warn( + let _ = PyErr::warn_bound( py, - py.get_type::(), + &py.get_type_bound::(), "PYDANTIC_ERRORS_OMIT_URL is deprecated, use PYDANTIC_ERRORS_INCLUDE_URL instead", 1, ); @@ -253,14 +253,17 @@ impl ValidationError { fn from_exception_data( py: Python, title: PyObject, - line_errors: &PyList, + line_errors: Bound<'_, PyList>, input_type: &str, hide_input: bool, ) -> PyResult> { Py::new( py, 
Self { - line_errors: line_errors.iter().map(PyLineError::try_from).collect::>()?, + line_errors: line_errors + .iter() + .map(|error| PyLineError::try_from(&error)) + .collect::>()?, title, input_type: InputType::try_from(input_type)?, hide_input, @@ -287,7 +290,7 @@ impl ValidationError { ) -> PyResult> { let url_prefix = get_url_prefix(py, include_url); let mut iteration_error = None; - let list = PyList::new( + let list = PyList::new_bound( py, // PyList::new takes ExactSizeIterator, so if an error occurs during iteration we // fill the list with None before returning the error; the list will then be thrown @@ -318,7 +321,7 @@ impl ValidationError { include_url: bool, include_context: bool, include_input: bool, - ) -> PyResult<&'py PyString> { + ) -> PyResult> { let state = SerializationState::new("iso8601", "utf8", "constants")?; let extra = state.extra(py, &SerMode::Json, true, false, false, true, None, None); let serializer = ValidationErrorSerializer { @@ -347,7 +350,7 @@ impl ValidationError { } }; let s = from_utf8(&bytes).map_err(json_py_err)?; - Ok(PyString::new(py, s)) + Ok(PyString::new_bound(py, s)) } fn __repr__(&self, py: Python) -> String { @@ -358,12 +361,12 @@ impl ValidationError { self.__repr__(py) } - fn __reduce__(slf: &PyCell) -> PyResult<(&PyAny, PyObject)> { + fn __reduce__<'py>(slf: &Bound<'py, Self>) -> PyResult<(Bound<'py, PyAny>, PyObject)> { let py = slf.py(); let callable = slf.getattr("from_exception_data")?; let borrow = slf.try_borrow()?; let args = ( - borrow.title.as_ref(py), + borrow.title.bind(py), borrow.errors(py, include_url_env(py), true, true)?, borrow.input_type.into_py(py), borrow.hide_input, @@ -463,11 +466,11 @@ impl From for ValLineError { } } -impl TryFrom<&PyAny> for PyLineError { +impl TryFrom<&Bound<'_, PyAny>> for PyLineError { type Error = PyErr; - fn try_from(value: &PyAny) -> PyResult { - let dict: &PyDict = value.downcast()?; + fn try_from(value: &Bound<'_, PyAny>) -> PyResult { + let dict = 
value.downcast::()?; let py = value.py(); let type_raw = dict @@ -485,7 +488,7 @@ impl TryFrom<&PyAny> for PyLineError { )); }; - let location = Location::try_from(dict.get_item("loc")?)?; + let location = Location::try_from(dict.get_item("loc")?.as_ref())?; let input_value = match dict.get_item("input")? { Some(i) => i.into_py(py), @@ -513,7 +516,7 @@ impl PyLineError { input_type: InputType, include_input: bool, ) -> PyResult { - let dict = PyDict::new(py); + let dict = PyDict::new_bound(py); dict.set_item("type", self.error_type.type_string())?; dict.set_item("loc", self.location.to_object(py))?; dict.set_item("msg", self.error_type.render_message(py, input_type)?)?; @@ -555,11 +558,11 @@ impl PyLineError { write!(output, " {message} [type={}", self.error_type.type_string())?; if !hide_input { - let input_value = self.input_value.as_ref(py); + let input_value = self.input_value.bind(py); let input_str = safe_repr(input_value); - truncate_input_value!(output, &input_str); + truncate_input_value!(output, &input_str.to_cow()); - if let Ok(type_) = input_value.get_type().name() { + if let Ok(type_) = input_value.get_type().qualname() { write!(output, ", input_type={type_}")?; } } @@ -663,13 +666,13 @@ impl<'py> Serialize for PyLineErrorSerializer<'py> { if self.include_input { map.serialize_entry( "input", - &self.extra.serialize_infer(self.line_error.input_value.as_ref(py)), + &self.extra.serialize_infer(self.line_error.input_value.bind(py)), )?; } if self.include_context { if let Some(context) = self.line_error.error_type.py_dict(py).map_err(py_err_json::)? 
{ - map.serialize_entry("ctx", &self.extra.serialize_infer(context.as_ref(py)))?; + map.serialize_entry("ctx", &self.extra.serialize_infer(context.bind(py)))?; } } if let Some(url_prefix) = self.url_prefix { diff --git a/src/errors/value_exception.rs b/src/errors/value_exception.rs index 3152044d1..e69fc0a9c 100644 --- a/src/errors/value_exception.rs +++ b/src/errors/value_exception.rs @@ -89,7 +89,7 @@ impl PydanticCustomError { } pub fn message(&self, py: Python) -> PyResult { - Self::format_message(&self.message_template, self.context.as_ref().map(|c| c.as_ref(py))) + Self::format_message(&self.message_template, self.context.as_ref().map(|c| c.bind(py))) } fn __str__(&self, py: Python) -> PyResult { @@ -99,7 +99,7 @@ impl PydanticCustomError { fn __repr__(&self, py: Python) -> PyResult { let msg = self.message(py)?; match self.context.as_ref() { - Some(ctx) => Ok(format!("{msg} [type={}, context={}]", self.error_type, ctx.as_ref(py))), + Some(ctx) => Ok(format!("{msg} [type={}, context={}]", self.error_type, ctx.bind(py))), None => Ok(format!("{msg} [type={}, context=None]", self.error_type)), } } @@ -115,14 +115,14 @@ impl PydanticCustomError { ValError::new(error_type, input) } - pub fn format_message(message_template: &str, context: Option<&PyDict>) -> PyResult { + pub fn format_message(message_template: &str, context: Option<&Bound<'_, PyDict>>) -> PyResult { let mut message = message_template.to_string(); if let Some(ctx) = context { - for (key, value) in ctx { - let key: &PyString = key.downcast()?; + for (key, value) in ctx.iter() { + let key = key.downcast::()?; if let Ok(py_str) = value.downcast::() { message = message.replace(&format!("{{{}}}", key.to_str()?), py_str.to_str()?); - } else if let Some(value_int) = extract_i64(value) { + } else if let Some(value_int) = extract_i64(&value) { message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string()); } else { // fallback for anything else just in case @@ -174,11 +174,7 @@ impl 
PydanticKnownError { fn __repr__(&self, py: Python) -> PyResult { let msg = self.message(py)?; match self.context(py)?.as_ref() { - Some(ctx) => Ok(format!( - "{msg} [type={}, context={}]", - self.error_type(), - ctx.as_ref(py) - )), + Some(ctx) => Ok(format!("{msg} [type={}, context={}]", self.error_type(), ctx.bind(py))), None => Ok(format!("{msg} [type={}, context=None]", self.error_type())), } } diff --git a/src/input/datetime.rs b/src/input/datetime.rs index ebb5675f2..9dcf9006e 100644 --- a/src/input/datetime.rs +++ b/src/input/datetime.rs @@ -20,7 +20,7 @@ use crate::tools::py_err; #[cfg_attr(debug_assertions, derive(Debug))] pub enum EitherDate<'a> { Raw(Date), - Py(&'a PyDate), + Py(Bound<'a, PyDate>), } impl<'a> From for EitherDate<'a> { @@ -29,13 +29,13 @@ impl<'a> From for EitherDate<'a> { } } -impl<'a> From<&'a PyDate> for EitherDate<'a> { - fn from(date: &'a PyDate) -> Self { +impl<'a> From> for EitherDate<'a> { + fn from(date: Bound<'a, PyDate>) -> Self { Self::Py(date) } } -pub fn pydate_as_date(py_date: &PyAny) -> PyResult { +pub fn pydate_as_date(py_date: &Bound<'_, PyAny>) -> PyResult { let py = py_date.py(); Ok(Date { year: py_date.getattr(intern!(py, "year"))?.extract()?, @@ -55,7 +55,7 @@ impl<'a> EitherDate<'a> { pub fn try_into_py(self, py: Python<'_>) -> PyResult { let date = match self { Self::Py(date) => Ok(date), - Self::Raw(date) => PyDate::new(py, date.year.into(), date.month, date.day), + Self::Raw(date) => PyDate::new_bound(py, date.year.into(), date.month, date.day), }?; Ok(date.into_py(py)) } @@ -64,7 +64,7 @@ impl<'a> EitherDate<'a> { #[cfg_attr(debug_assertions, derive(Debug))] pub enum EitherTime<'a> { Raw(Time), - Py(&'a PyTime), + Py(Bound<'a, PyTime>), } impl<'a> From