diff --git a/serde/src/private/de.rs b/serde/src/private/de.rs index 883e6909c..607dbeea0 100644 --- a/serde/src/private/de.rs +++ b/serde/src/private/de.rs @@ -206,6 +206,7 @@ mod content { // This issue is tracking making some of this stuff public: // https://github.com/serde-rs/serde/issues/741 + use crate::de::VariantAccess; use crate::lib::*; use crate::actually_private; @@ -249,6 +250,7 @@ mod content { Newtype(Box>), Seq(Vec>), Map(Vec<(Content<'de>, Content<'de>)>), + Enum(Box<(Content<'de>, Content<'de>)>), } impl<'de> Content<'de> { @@ -286,6 +288,7 @@ mod content { Content::Newtype(_) => Unexpected::NewtypeStruct, Content::Seq(_) => Unexpected::Seq, Content::Map(_) => Unexpected::Map, + Content::Enum(_) => Unexpected::Enum, } } } @@ -510,13 +513,13 @@ mod content { Ok(Content::Map(vec)) } - fn visit_enum(self, _visitor: V) -> Result + fn visit_enum(self, visitor: V) -> Result where V: EnumAccess<'de>, { - Err(de::Error::custom( - "untagged and internally tagged enums do not support enum input", - )) + let (variant, access) = tri!(visitor.variant()); + let value = tri!(access.newtype_variant()); + Ok(Content::Enum(Box::new((variant, value)))) } } @@ -1146,6 +1149,7 @@ mod content { Content::Newtype(v) => visitor.visit_newtype_struct(ContentDeserializer::new(*v)), Content::Seq(v) => visit_content_seq(v, visitor), Content::Map(v) => visit_content_map(v, visitor), + Content::Enum(v) => visitor.visit_enum(EnumDeserializer::new(v.0, Some(v.1))), } } @@ -1747,6 +1751,9 @@ mod content { } Content::Seq(ref v) => visit_content_seq_ref(v, visitor), Content::Map(ref v) => visit_content_map_ref(v, visitor), + Content::Enum(ref v) => { + visitor.visit_enum(EnumRefDeserializer::new(&v.0, Some(&v.1))) + } } } @@ -2091,6 +2098,19 @@ mod content { err: PhantomData, } + impl<'a, 'de: 'a, E> EnumRefDeserializer<'a, 'de, E> + where + E: de::Error, + { + fn new(variant: &'a Content<'de>, value: Option<&'a Content<'de>>) -> Self { + Self { + variant, + value, + err: PhantomData, + } + } + } + impl<'de, 'a, E> de::EnumAccess<'de> for EnumRefDeserializer<'a, 'de, E> where E: de::Error,