Skip to content

Commit

Permalink
Speed up deserialization by using exceptions for error propagation ra…
Browse files Browse the repository at this point in the history
…ther than the result monad.

 * User facing API remains unchanged
 * This change speeds up deserialization by a factor of ~2
  • Loading branch information
andersfugmann committed Dec 30, 2023
1 parent b96ab35 commit 86dd890
Show file tree
Hide file tree
Showing 15 changed files with 407 additions and 290 deletions.
140 changes: 66 additions & 74 deletions src/ocaml_protoc_plugin/deserialize.ml
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
(** Module for deserializing values *)

open StdLabels
open Result

module S = Spec.Deserialize
module C = S.C
open S

type 'a sentinal = unit -> 'a Result.t
type 'a decoder = Field.t -> 'a Result.t
type 'a sentinal = unit -> 'a
type 'a decoder = Field.t -> 'a

type (_, _) sentinal_list =
| SNil : ('a, 'a) sentinal_list
| SCons : ('a sentinal) * ('b, 'c) sentinal_list -> ('a -> 'b, 'c) sentinal_list

let error_wrong_field str field : _ Result.t =
`Wrong_field_type (str, field) |> Result.fail

let error_illegal_value str field : _ Result.t = `Illegal_value (str, field) |> Result.fail
let error_required_field_missing: _ Result.t = `Required_field_missing |> Result.fail
let error_wrong_field str field = Result.raise (`Wrong_field_type (str, field))
let error_illegal_value str field = Result.raise (`Illegal_value (str, field))
let error_required_field_missing () = Result.raise `Required_field_missing

let read_varint ~signed ~type_name =
let open! Infix.Int64 in
Expand All @@ -29,26 +26,26 @@ let read_varint ~signed ~type_name =
| true -> (v / 2L * -1L) - 1L
| false -> v
in
return v
v
end
| field -> error_wrong_field type_name field

let read_varint32 ~signed ~type_name field =
read_varint ~signed ~type_name field >>| Int64.to_int32
read_varint ~signed ~type_name field |> Int64.to_int32

let rec type_of_spec: type a. a spec -> 'b * a decoder =
let int_of_int32 spec =
let (tpe, f) = type_of_spec spec in
let f field =
f field >>| Int32.to_int
f field |> Int32.to_int
in
(tpe, f)
in

let int_of_uint32 spec =
let (tpe, f) = type_of_spec spec in
let f field =
f field >>| (fun v ->
f field |> (fun v ->
match Sys.word_size with
| 32 ->
(* If the high bit is set, we cannot represent it anyways *)
Expand All @@ -65,7 +62,7 @@ let rec type_of_spec: type a. a spec -> 'b * a decoder =
let int_of_int64 spec =
let (tpe, f) = type_of_spec spec in
let f field =
f field >>| Int64.to_int
f field |> Int64.to_int
in
(tpe, f)
in
Expand All @@ -74,17 +71,17 @@ let rec type_of_spec: type a. a spec -> 'b * a decoder =
let (tpe, f) = type_of_spec spec in
let f field =
(* If high-bit is set, we cannot represent it *)
f field >>| Int64.to_int
f field |> Int64.to_int
in
(tpe, f)
in

function
| Double -> (`Fixed_64_bit, function
| Field.Fixed_64_bit v -> return (Int64.float_of_bits v)
| Field.Fixed_64_bit v -> Int64.float_of_bits v
| field -> error_wrong_field "double" field)
| Float -> (`Fixed_32_bit, function
| Field.Fixed_32_bit v -> return (Int32.float_of_bits v)
| Field.Fixed_32_bit v -> Int32.float_of_bits v
| field -> error_wrong_field "float" field)
| Int32 -> (`Varint, read_varint32 ~signed:false ~type_name:"int32")
| Int32_int -> int_of_int32 Int32
Expand All @@ -99,33 +96,33 @@ let rec type_of_spec: type a. a spec -> 'b * a decoder =
| SInt64 -> (`Varint, read_varint ~signed:true ~type_name:"sint64")
| SInt64_int -> int_of_int64 SInt64
| Fixed32 -> (`Fixed_32_bit, function
| Field.Fixed_32_bit v -> return (v)
| Field.Fixed_32_bit v -> v
| field -> error_wrong_field "fixed32" field)
| Fixed32_int -> int_of_int32 Fixed32
| Fixed64 -> (`Fixed_64_bit, function
| Field.Fixed_64_bit v -> return v
| Field.Fixed_64_bit v -> v
| field -> error_wrong_field "fixed64" field)
| Fixed64_int -> int_of_int64 Fixed64

| SFixed32 -> (`Fixed_32_bit, function
| Field.Fixed_32_bit v -> return v
| Field.Fixed_32_bit v -> v
| field -> error_wrong_field "sfixed32" field)
| SFixed32_int -> int_of_int32 SFixed32
| SFixed64 -> (`Fixed_64_bit, function
| Field.Fixed_64_bit v -> return v
| Field.Fixed_64_bit v -> v
| field -> error_wrong_field "sfixed64" field)
| SFixed64_int -> int_of_int64 SFixed64
| Bool -> (`Varint, function
| Field.Varint v -> return (Int64.equal v 0L |> not)
| Field.Varint v -> Int64.equal v 0L |> not
| field -> error_wrong_field "bool" field)
| Enum of_int -> (`Varint, function
| Field.Varint v -> of_int (Int64.to_int v)
| field -> error_wrong_field "enum" field)
| String -> (`Length_delimited, function
| Field.Length_delimited {offset; length; data} -> return (String.sub ~pos:offset ~len:length data)
| Field.Length_delimited {offset; length; data} -> String.sub ~pos:offset ~len:length data
| field -> error_wrong_field "string" field)
| Bytes -> (`Length_delimited, function
| Field.Length_delimited {offset; length; data} -> return (String.sub ~pos:offset ~len:length data |> Bytes.of_string)
| Field.Length_delimited {offset; length; data} -> String.sub ~pos:offset ~len:length data |> Bytes.of_string
| field -> error_wrong_field "string" field)
| Message from_proto -> (`Length_delimited, function
| Field.Length_delimited {offset; length; data} -> from_proto (Reader.create ~offset ~length data)
Expand All @@ -142,25 +139,25 @@ let sentinal: type a. a compound -> (int * unit decoder) list * a sentinal = fun
| Basic (index, (Message deser), _) ->
let v = ref None in
let get () = match !v with
| None -> error_required_field_missing
| Some v -> return v
| None -> error_required_field_missing ()
| Some v -> v
in
let read = function
| Field.Length_delimited {offset; length; data} ->
let reader = Reader.create ~length ~offset data in
deser reader >>| fun message -> v := Some message
deser reader |> fun message -> v := Some message
| field -> error_wrong_field "message" field
in
([index, read], get)
| Basic (index, spec, Required) ->
let _, read = type_of_spec spec in
let v = ref None in
let get () = match !v with
| Some v -> return v
| None -> error_required_field_missing
| Some v -> v
| None -> error_required_field_missing ()
in
let read field =
read field >>| fun value -> v := Some value
read field |> fun value -> v := Some value
in
([index, read], get)
| Basic (index, spec, default) ->
Expand All @@ -170,24 +167,21 @@ let sentinal: type a. a compound -> (int * unit decoder) list * a sentinal = fun
| Required
| Proto3 -> begin
default_of_field_type field_type
|> read
|> function
| Ok v -> v
| Error _ -> failwith "Cannot decode default field value"
|> fun v -> try read v with Result.Error _ -> failwith "Cannot decode default field value"
end
in
let v = ref default in
let get () = return !v in
let get () = !v in
let read field =
read field >>| fun value -> v := value
read field |> fun value -> v := value
in
([index, read], get)
| Basic_opt (index, spec) ->
let _, read = type_of_spec spec in
let v = ref None in
let get () = return !v in
let get () = !v in
let read field =
read field >>| fun value -> v := Some value
read field |> fun value -> v := Some value
in
([index, read], get)
| Repeated (index, spec, _) ->
Expand All @@ -198,33 +192,33 @@ let sentinal: type a. a compound -> (int * unit decoder) list * a sentinal = fun
| `Fixed_32_bit -> Some Reader.read_fixed32
in
let rec read_repeated reader decode read_f = match Reader.has_more reader with
| false -> return ()
| false -> ()
| true ->
decode reader >>= fun field ->
read_f field >>= fun () ->
decode reader |> fun field ->
read_f field |> fun () ->
read_repeated reader decode read_f
in
let (field_type, read_type) = type_of_spec spec in
let v = ref [] in
let get () = return (List.rev !v) in
let get () = List.rev !v in
let rec read field = match field, read_field field_type with
| (Field.Length_delimited _ as field), None ->
read_type field >>| fun v' -> v := v' :: !v
read_type field |> fun v' -> v := v' :: !v
| Field.Length_delimited { offset; length; data }, Some read_field ->
read_repeated (Reader.create ~offset ~length data) read_field read
| field, _ -> read_type field >>| fun v' -> v := v' :: !v
| field, _ -> read_type field |> fun v' -> v := v' :: !v
in
([index, read], get)
| Oneof oneofs ->
let make_reader: a ref -> a oneof -> (int * unit decoder) = fun v (Oneof_elem (index, spec, constr)) ->
let _, read = type_of_spec spec in
let read field =
read field >>| fun value -> v := (constr value)
read field |> fun value -> v := (constr value)
in
(index, read)
in
let v = ref `not_set in
let get () = return !v in
let get () = !v in
List.map ~f:(make_reader v) oneofs, get

module Map = struct
Expand All @@ -246,21 +240,19 @@ let read_fields_map extension_ranges reader_list =
let map = Map.of_alist_exn reader_list in
let rec read reader =
match Reader.has_more reader with
| false -> return (List.rev !extensions)
| true -> begin
match Reader.read_field reader with
| Ok (index, field) -> begin
match Map.find_opt index map with
| Some f ->
f field >>= fun () ->
read reader
| None when in_extension_ranges extension_ranges index ->
extensions := (index, field) :: !extensions;
read reader
| None ->
read reader
end
| Error err -> Error err
| false -> List.rev !extensions
| true ->
begin
let (index, field) = Reader.read_field reader in
match Map.find_opt index map with
| Some f ->
f field |> fun () ->
read reader
| None when in_extension_ranges extension_ranges index ->
extensions := (index, field) :: !extensions;
read reader
| None ->
read reader
end
in
read
Expand All @@ -271,30 +263,30 @@ let read_fields_array extension_ranges max_index reader_list =
let default index field =
match in_extension_ranges extension_ranges index with
| true -> extensions := (index, field) :: !extensions;
return ()
()
| false ->
return ()
()
in
let readers = Array.init (max_index + 1) ~f:(fun _ -> default) in
List.iter ~f:(fun (idx, f) -> readers.(idx) <- (fun _ -> f)) reader_list;

let rec read reader =
match Reader.has_more reader with
| false -> return (List.rev !extensions)
| false -> List.rev !extensions
| true -> begin
match Reader.read_field reader with
| Ok (index, field) when index <= max_index ->
readers.(index) index field >>= fun () ->
let (index, field) = Reader.read_field reader in
match index <= max_index with
| true ->
readers.(index) index field |> fun () ->
read reader
| Ok (index, field) ->
default index field >>= fun () ->
| false ->
default index field |> fun () ->
read reader
| Error err -> Error err
end
in
read

let deserialize: type constr t. (int * int) list -> (constr, t) compound_list -> ((int * Field.t) list -> constr) -> Reader.t -> t Result.t = fun extension_ranges spec constr ->
let deserialize: type constr t. (int * int) list -> (constr, t) compound_list -> ((int * Field.t) list -> constr) -> Reader.t -> t = fun extension_ranges spec constr ->
let max_index =
let rec inner: type a b. int -> (a, b) compound_list -> int = fun acc -> function
| Cons (Oneof oneofs, rest) ->
Expand Down Expand Up @@ -332,10 +324,10 @@ let deserialize: type constr t. (int * int) list -> (constr, t) compound_list ->
| true -> read_fields_array extension_ranges max_index
| false -> read_fields_map extension_ranges
in
let rec apply: type constr t. constr -> (constr, t) sentinal_list -> t Result.t = fun constr -> function
let rec apply: type constr t. constr -> (constr, t) sentinal_list -> t = fun constr -> function
| SCons (sentinal, rest) ->
sentinal () >>= fun v -> apply (constr v) rest
| SNil -> return constr
sentinal () |> fun v -> apply (constr v) rest
| SNil -> constr
in
(* We first make a list of sentinal_getters, which we can map to the constr *)
let rec make_sentinals: type a b. (a, b) compound_list -> (a, b) sentinal_list * (int * unit decoder) list = function
Expand All @@ -348,4 +340,4 @@ let deserialize: type constr t. (int * int) list -> (constr, t) compound_list ->
fun reader ->
let sentinals, reader_list = make_sentinals spec in
(* Read the fields one by one, and apply the reader - if found *)
read_fields reader_list reader >>= fun extensions -> apply (constr extensions) sentinals
read_fields reader_list reader |> fun extensions -> apply (constr extensions) sentinals
5 changes: 3 additions & 2 deletions src/ocaml_protoc_plugin/extensions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ let show : t -> string = Format.asprintf "%a" pp
let equal _ _ = true
let compare _ _ = 0

let get: ('b -> 'b, 'b) Deserialize.S.compound_list -> t -> 'b Result.t = fun spec t ->
let get: ('b -> 'b, 'b) Deserialize.S.compound_list -> t -> 'b = fun spec t ->
let writer = Writer.of_list t in
(* Back and forth - its the same, no? *)
let reader = Writer.contents writer |> Reader.create in
Expand All @@ -16,7 +16,8 @@ let get: ('b -> 'b, 'b) Deserialize.S.compound_list -> t -> 'b Result.t = fun sp
let set: ('a -> Writer.t, Writer.t) Serialize.S.compound_list -> t -> 'a -> t = fun spec t v ->
let writer = Serialize.serialize [] spec [] v in
let reader = Writer.contents writer |> Reader.create in
match Reader.to_list reader |> Result.get ~msg:"Internal serialization fail" with
match Reader.to_list reader with
| (((index, _) :: _) as fields) ->
(List.filter ~f:(fun (i, _) -> i != index) t) @ fields
| [] -> t
| exception Result.Error _ -> failwith "Internal serialization fail"
2 changes: 1 addition & 1 deletion src/ocaml_protoc_plugin/extensions.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ val pp : Format.formatter -> t -> unit
val show : t -> string
val equal : t -> t -> bool
val compare : t -> t -> int
val get : ('b -> 'b, 'b) Deserialize.S.compound_list -> t -> 'b Result.t
val get : ('b -> 'b, 'b) Deserialize.S.compound_list -> t -> 'b
val set : ('a -> Writer.t, Writer.t) Spec.Serialize.compound_list -> t -> 'a -> t
Loading

0 comments on commit 86dd890

Please sign in to comment.