Support manually specifying case labels for union validators by dmontagu · Pull Request #841 · pydantic/pydantic-core · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Support manually specifying case labels for union validators #841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ def type_dict_schema(typed_dict) -> dict[str, Any]: # noqa: C901
schema = {'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': schema_ref_validator}
elif fr_arg == 'Dict[Hashable, CoreSchema]':
schema = {'type': 'dict', 'keys_schema': {'type': 'any'}, 'values_schema': schema_ref_validator}
elif fr_arg == 'List[Union[CoreSchema, Tuple[CoreSchema, str]]]':
schema = {
'type': 'list',
'items_schema': {
'type': 'union',
'choices': [
schema_ref_validator,
{'type': 'tuple-positional', 'items_schema': [schema_ref_validator, {'type': 'str'}]},
],
},
}
else:
raise ValueError(f'Unknown Schema forward ref: {fr_arg}')
else:
Expand Down
8 changes: 4 additions & 4 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union

if sys.version_info < (3, 11):
from typing_extensions import Protocol, Required, TypeAlias
Expand Down Expand Up @@ -2454,7 +2454,7 @@ def nullable_schema(

class UnionSchema(TypedDict, total=False):
type: Required[Literal['union']]
choices: Required[List[CoreSchema]]
choices: Required[List[Union[CoreSchema, Tuple[CoreSchema, str]]]]
# default true, whether to automatically collapse unions with one element to the inner validator
auto_collapse: bool
custom_error_type: str
Expand All @@ -2467,7 +2467,7 @@ class UnionSchema(TypedDict, total=False):


def union_schema(
choices: list[CoreSchema],
choices: list[CoreSchema | tuple[CoreSchema, str]],
*,
auto_collapse: bool | None = None,
custom_error_type: str | None = None,
Expand All @@ -2491,7 +2491,7 @@ def union_schema(
```

Args:
choices: The schemas to match
choices: The schemas to match. If a tuple, the second item is used as the label for the case.
auto_collapse: whether to automatically collapse unions with one element to the inner validator, default true
custom_error_type: The custom error type to use if the validation fails
custom_error_message: The custom error message to use if the validation fails
Expand Down
10 changes: 8 additions & 2 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use pyo3::types::{PyDict, PyList, PyTuple};
use std::borrow::Cow;

use crate::build_tools::py_schema_err;
Expand Down Expand Up @@ -31,7 +31,13 @@ impl BuildSerializer for UnionSerializer {
let choices: Vec<CombinedSerializer> = schema
.get_as_req::<&PyList>(intern!(py, "choices"))?
.iter()
.map(|choice| CombinedSerializer::build(choice.downcast()?, config, definitions))
.map(|choice| {
let choice: &PyAny = match choice.downcast::<PyTuple>() {
Ok(py_tuple) => py_tuple.get_item(0)?,
Err(_) => choice,
};
CombinedSerializer::build(choice.downcast()?, config, definitions)
})
.collect::<PyResult<Vec<CombinedSerializer>>>()?;

Self::from_choices(choices)
Expand Down
70 changes: 46 additions & 24 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::Write;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::types::{PyDict, PyList, PyString, PyTuple};
use pyo3::{intern, PyTraverseError, PyVisit};

use crate::build_tools::py_schema_err;
Expand All @@ -19,7 +19,7 @@ use super::{build_validator, BuildValidator, CombinedValidator, Definitions, Def

#[derive(Debug, Clone)]
pub struct UnionValidator {
choices: Vec<CombinedValidator>,
choices: Vec<(CombinedValidator, Option<String>)>,
custom_error: Option<CustomError>,
strict: bool,
name: String,
Expand All @@ -36,18 +36,33 @@ impl BuildValidator for UnionValidator {
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let py = schema.py();
let choices: Vec<CombinedValidator> = schema
let choices: Vec<(CombinedValidator, Option<String>)> = schema
.get_as_req::<&PyList>(intern!(py, "choices"))?
.iter()
.map(|choice| build_validator(choice, config, definitions))
.collect::<PyResult<Vec<CombinedValidator>>>()?;
.map(|choice| {
let mut label: Option<String> = None;
let choice: &PyAny = match choice.downcast::<PyTuple>() {
Ok(py_tuple) => {
let choice = py_tuple.get_item(0)?;
label = Some(py_tuple.get_item(1)?.to_string());
choice
}
Err(_) => choice,
};
Ok((build_validator(choice, config, definitions)?, label))
})
.collect::<PyResult<Vec<(CombinedValidator, Option<String>)>>>()?;

let auto_collapse = || schema.get_as_req(intern!(py, "auto_collapse")).unwrap_or(true);
match choices.len() {
0 => py_schema_err!("One or more union choices required"),
1 if auto_collapse() => Ok(choices.into_iter().next().unwrap()),
1 if auto_collapse() => Ok(choices.into_iter().next().unwrap().0),
_ => {
let descr = choices.iter().map(Validator::get_name).collect::<Vec<_>>().join(",");
let descr = choices
.iter()
.map(|(choice, label)| label.as_deref().unwrap_or(choice.get_name()))
.collect::<Vec<_>>()
.join(",");

Ok(Self {
choices,
Expand Down Expand Up @@ -77,7 +92,12 @@ impl UnionValidator {
}
}

impl_py_gc_traverse!(UnionValidator { choices });
impl PyGcTraverse for UnionValidator {
    /// Forwards garbage-collector traversal to the validator of every union choice.
    ///
    /// Each entry in `choices` is a `(validator, label)` pair. Only the validator is
    /// visited: the optional case label is a plain Rust `String`, so it has nothing
    /// to traverse.
    ///
    /// # Errors
    /// Returns the first `PyTraverseError` produced by a choice's validator; the
    /// remaining choices are not visited after that.
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        for (validator, _label) in &self.choices {
            validator.py_gc_traverse(visit)?;
        }
        Ok(())
    }
}

impl Validator for UnionValidator {
fn validate<'s, 'data>(
Expand All @@ -94,7 +114,9 @@ impl Validator for UnionValidator {
if let Some(res) = self
.choices
.iter()
.map(|validator| validator.validate(py, input, &ultra_strict_extra, definitions, recursion_guard))
.map(|(validator, _label)| {
validator.validate(py, input, &ultra_strict_extra, definitions, recursion_guard)
})
.find(ValResult::is_ok)
{
return res;
Expand All @@ -108,18 +130,17 @@ impl Validator for UnionValidator {
};
let strict_extra = extra.as_strict(false);

for validator in &self.choices {
for (validator, label) in &self.choices {
let line_errors = match validator.validate(py, input, &strict_extra, definitions, recursion_guard) {
Err(ValError::LineErrors(line_errors)) => line_errors,
otherwise => return otherwise,
};

if let Some(ref mut errors) = errors {
errors.extend(
line_errors
.into_iter()
.map(|err| err.with_outer_location(validator.get_name().into())),
);
errors.extend(line_errors.into_iter().map(|err| {
let case_label = label.as_deref().unwrap_or(validator.get_name());
err.with_outer_location(case_label.into())
}));
}
}

Expand All @@ -132,7 +153,9 @@ impl Validator for UnionValidator {
if let Some(res) = self
.choices
.iter()
.map(|validator| validator.validate(py, input, &strict_extra, definitions, recursion_guard))
.map(|(validator, _label)| {
validator.validate(py, input, &strict_extra, definitions, recursion_guard)
})
.find(ValResult::is_ok)
{
return res;
Expand All @@ -145,18 +168,17 @@ impl Validator for UnionValidator {
};

// 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate
for validator in &self.choices {
for (validator, label) in &self.choices {
let line_errors = match validator.validate(py, input, extra, definitions, recursion_guard) {
Err(ValError::LineErrors(line_errors)) => line_errors,
success => return success,
};

if let Some(ref mut errors) = errors {
errors.extend(
line_errors
.into_iter()
.map(|err| err.with_outer_location(validator.get_name().into())),
);
errors.extend(line_errors.into_iter().map(|err| {
let case_label = label.as_deref().unwrap_or(validator.get_name());
err.with_outer_location(case_label.into())
}));
}
}

Expand All @@ -171,15 +193,15 @@ impl Validator for UnionValidator {
) -> bool {
self.choices
.iter()
.any(|v| v.different_strict_behavior(definitions, ultra_strict))
.any(|(v, _)| v.different_strict_behavior(definitions, ultra_strict))
}

fn get_name(&self) -> &str {
&self.name
}

fn complete(&mut self, definitions: &DefinitionsBuilder<CombinedValidator>) -> PyResult<()> {
self.choices.iter_mut().try_for_each(|v| v.complete(definitions))?;
self.choices.iter_mut().try_for_each(|(v, _)| v.complete(definitions))?;
self.strict_required = self.different_strict_behavior(Some(definitions), false);
self.ultra_strict_required = self.different_strict_behavior(Some(definitions), true);
Ok(())
Expand Down
9 changes: 7 additions & 2 deletions tests/serializers/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@ def __init__(self, **kwargs) -> None:
setattr(self, name, value)


@pytest.mark.parametrize('bool_case_label', [False, True])
@pytest.mark.parametrize('int_case_label', [False, True])
@pytest.mark.parametrize('input_value,expected_value', [(True, True), (False, False), (1, 1), (123, 123), (-42, -42)])
def test_union_bool_int(input_value, expected_value):
s = SchemaSerializer(core_schema.union_schema([core_schema.bool_schema(), core_schema.int_schema()]))
def test_union_bool_int(input_value, expected_value, bool_case_label, int_case_label):
bool_case = core_schema.bool_schema() if not bool_case_label else (core_schema.bool_schema(), 'my_bool_label')
int_case = core_schema.int_schema() if not int_case_label else (core_schema.int_schema(), 'my_int_label')
s = SchemaSerializer(core_schema.union_schema([bool_case, int_case]))

assert s.to_python(input_value) == expected_value
assert s.to_python(input_value, mode='json') == expected_value
assert s.to_json(input_value) == json.dumps(expected_value).encode()
Expand Down
20 changes: 20 additions & 0 deletions tests/validators/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,23 @@ def test_strict_reference():

assert repr(v.validate_python((1, 2))) == '(1.0, 2)'
assert repr(v.validate_python((1.0, (2.0, 3)))) == '(1.0, (2.0, 3))'


def test_case_labels():
    """A manually labelled union choice is reported under its label in errors and the schema title."""
    # Only the int choice carries an explicit label; the other choices fall back to their default names.
    choices = [{'type': 'none'}, ({'type': 'int'}, 'my_label'), {'type': 'str'}]
    v = SchemaValidator({'type': 'union', 'choices': choices})

    # Inputs matching one of the choices are returned unchanged.
    assert v.validate_python(None) is None
    assert v.validate_python(1) == 1

    # A value matching no choice produces one error per choice, located by label or default name.
    expected_title = r'3 validation errors for union\[none,my_label,str]'
    with pytest.raises(ValidationError, match=expected_title) as exc_info:
        v.validate_python(1.5)

    expected_errors = [
        {'input': 1.5, 'loc': ('none',), 'msg': 'Input should be None', 'type': 'none_required'},
        {
            'input': 1.5,
            'loc': ('my_label',),
            'msg': 'Input should be a valid integer, got a number with a fractional part',
            'type': 'int_from_float',
        },
        {'input': 1.5, 'loc': ('str',), 'msg': 'Input should be a valid string', 'type': 'string_type'},
    ]
    assert exc_info.value.errors(include_url=False) == expected_errors
0