Skip to content

Commit

Permalink
Wip
Browse files Browse the repository at this point in the history
  • Loading branch information
andersfugmann committed Jan 29, 2024
1 parent 3d800ed commit 87175e3
Showing 1 changed file with 30 additions and 42 deletions.
72 changes: 30 additions & 42 deletions src/ocaml_protoc_plugin/deserialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ module S = Spec.Deserialize
module C = S.C
open S

type required = Required | Optional

type 'a reader = 'a -> Reader.t -> Field.field_type -> 'a
type ('a, 'b) getter = 'a -> 'b
type 'a field_spec = (int * 'a reader)
type _ value = Value: ('b field_spec list * required * 'b * ('b, 'a) getter) -> 'a value
type _ value = Value: ('b field_spec list * 'b * ('b, 'a) getter) -> 'a value
type extensions = (int * Field.t) list

type (_, _) value_list =
Expand Down Expand Up @@ -88,7 +86,7 @@ let read_of_spec: type a. a spec -> Field.field_type * (Reader.t -> a) = functio
| Message (from_proto, _merge) -> Length_delimited, fun reader ->
let Field.{ offset; length; data } = Reader.read_length_delimited reader in
from_proto (Reader.create ~offset ~length data)

(*
let default_value: type a. a spec -> a = function
| Double -> 0.0
| Float -> 0.0
Expand Down Expand Up @@ -117,7 +115,7 @@ let default_value: type a. a spec -> a = function
| SFixed64_int -> 0
| Enum of_int -> of_int 0
| Bool -> false

*)
let id x = x
let keep_last _ v = v

Expand All @@ -136,21 +134,20 @@ let value: type a. a compound -> a value = function
| Some v1 -> Some (merge v1 v2)
in
let read = read_field ~read:(read_of_spec spec) ~map in
let getter = function Some v -> v | None -> failwith "Get called on unset required field" in
Value ([(index, read)], Required, None, getter)
| Basic (index, spec, default) ->
let getter = function Some v -> v | None -> error_required_field_missing () in
Value ([(index, read)], None, getter)
| Basic (index, spec, None) ->
Printf.eprintf "Really no default for index %d\n" index;
(* I think we need to create a new _req type for proto2 required fields (they would not have a default, since they are required *)
let map _ v2 = Some v2 in
let read = read_field ~read:(read_of_spec spec) ~map in
let getter = function Some v -> v | None -> error_required_field_missing () in
Value ([(index, read)], None, getter)
| Basic (index, spec, Some default) ->
let map = keep_last
in
let read = read_field ~read:(read_of_spec spec) ~map in
let required = match default with
| Some _ -> Optional
| None -> Required
in
let default = match default with
| None -> default_value spec
| Some default -> default
in
Value ([(index, read)], required, default, id)
Value ([(index, read)], default, id)
| Basic_opt (index, spec) ->
let map = match spec with
| Message (_, merge) ->
Expand All @@ -163,7 +160,7 @@ let value: type a. a compound -> a value = function
| _ -> fun _ v -> Some v (* Keep last for all other non-repeated types *)
in
let read = read_field ~read:(read_of_spec spec) ~map in
Value ([(index, read)], Optional, None, id)
Value ([(index, read)], None, id)
| Repeated (index, spec, Packed) ->
let field_type, read_f = read_of_spec spec in
let rec read_packed_values read_f acc reader =
Expand All @@ -182,16 +179,16 @@ let value: type a. a compound -> a value = function
let field = Reader.read_field_content ft reader in
error_wrong_field "Deserialize" field
in
Value ([(index, read)], Optional, [], List.rev)
Value ([(index, read)], [], List.rev)
| Repeated (index, spec, Not_packed) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun vs v -> v :: vs) in
Value ([(index, read)], Optional, [], List.rev)
Value ([(index, read)], [], List.rev)
| Oneof oneofs ->
let make_reader: a oneof -> a field_spec = fun (Oneof_elem (index, spec, constr)) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun _ -> constr) in
(index, read)
in
Value (List.map ~f:make_reader oneofs, Optional, `not_set, id)
Value (List.map ~f:make_reader oneofs, `not_set, id)

module IntMap = Map.Make(struct type t = int let compare = Int.compare end)

Expand All @@ -204,15 +201,12 @@ let deserialize_full: type constr a. extension_ranges -> (constr, a) value_list
| VNil -> NNil
| VNil_ext -> NNil_ext
(* Consider optimizing when optional is true *)
| VCons (Value (fields, required, default, getter), rest) ->
let v = ref (default, required) in
let get () = match !v with
| _, Required -> error_required_field_missing ();
| v, Optional-> getter v
in
| VCons (Value (fields, default, getter), rest) ->
let v = ref default in
let get () = getter !v in
let fields =
List.map ~f:(fun (index, read) ->
let read reader field_type = let v' = fst !v in v := (read v' reader field_type, Optional) in
let read reader field_type = (v := read !v reader field_type) in
(index, read)
) fields
in
Expand Down Expand Up @@ -297,34 +291,27 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
| VNil_ext when idx = Int.max_int ->
constr (List.rev extensions)
(* All fields read successfully. Apply extensions and return result. *)
| VCons (Value ([index, read_f], _required, default, get), vs) when index = idx ->
| VCons (Value ([index, read_f], default, get), vs) when index = idx ->
(* Read all values, and apply constructor once all fields have been read.
This pattern is the most likely to be matched for all values, and is added
as an optimization to avoid reconstructing the value list for each recursion.
*)
let default, tpe, idx = read_repeated tpe index read_f default reader in
let constr = (constr (get default)) in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (Value ((index, read_f) :: fields, _required, default, get), vs) when index = idx ->
| VCons (Value ((index, read_f) :: fields, default, get), vs) when index = idx ->
(* Read all values for the given field *)
let default, tpe, idx = read_repeated tpe index read_f default reader in
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, Optional, default, get), vs))
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, default, get), vs))
| vs when in_extension_ranges extension_ranges idx ->
(* Extensions may be sent inline. Store all valid extensions, before starting to apply constructors *)
let extensions = (idx, Reader.read_field_content tpe reader) :: extensions in
let (tpe, idx) = next_field reader in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (Value ([], Required, _default, _get), _vs) ->
(* If there are no more fields to be read we will never find the value.
If all values are read, then raise, else revert to full deserialization *)
begin match (idx = Int.max_int) with
| true -> error_required_field_missing ()
| false -> raise Restart_full
end
| VCons (Value (_ :: fields, optional, default, get), vs) ->
| VCons (Value (_ :: fields, default, get), vs) ->
(* Drop the field, as we dont expect to find it. *)
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, optional, default, get), vs))
| VCons (Value ([], Optional, default, get), vs) ->
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, default, get), vs))
| VCons (Value ([], default, get), vs) ->
(* Apply destructor. This case is only relevant for oneof fields *)
read_values extension_ranges tpe idx reader (constr (get default)) extensions vs
| VNil | VNil_ext ->
Expand All @@ -342,6 +329,7 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
let (tpe, idx) = next_field reader in
try
read_values extension_ranges tpe idx reader constr [] values
with Restart_full ->
with (Restart_full | Result.Error `Required_field_missing) ->
(* Revert to full deserialization *)
Reader.reset reader offset;
deserialize_full extension_ranges values constr reader

0 comments on commit 87175e3

Please sign in to comment.