Skip to content

Commit

Permalink
Merge pull request #53 from issuu/andersfugmann/all_in_one
Browse files Browse the repository at this point in the history
Fix bug in name lookup leading to unusable code generated.
  • Loading branch information
andersfugmann authored Jan 7, 2024
2 parents ba82d22 + 970466b commit 6cb91b6
Show file tree
Hide file tree
Showing 6 changed files with 838 additions and 151 deletions.
127 changes: 77 additions & 50 deletions src/plugin/emit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ let emit_enum_type ~scope ~params
{module_name; signature; implementation}

let emit_service_type ~options scope ServiceDescriptorProto.{ name; method' = methods; _ } =
let emit_method t local_scope scope service_name MethodDescriptorProto.{ name; input_type; output_type; _} =
let emit_method signature implementation local_scope scope service_name MethodDescriptorProto.{ name; input_type; output_type; _} =
let name = Option.value_exn name in
let mangle_f = match Scope.has_mangle_option options with
| false -> fun id -> id
Expand All @@ -77,32 +77,42 @@ let emit_service_type ~options scope ServiceDescriptorProto.{ name; method' = me
let input_t = Scope.get_scoped_name scope ~postfix:"t" input_type in
let output = Scope.get_scoped_name scope output_type in
let output_t = Scope.get_scoped_name scope ~postfix:"t" output_type in
Code.emit t `Begin "module %s = struct" capitalized_name;
Code.emit t `None "let package_name = %s" (Option.value_map ~default:"None" ~f:(fun n -> sprintf "Some \"%s\"" n) package_name_opt);
Code.emit t `None "let service_name = \"%s\"" service_name;
Code.emit t `None "let method_name = \"%s\"" name;
Code.emit t `None "let name = \"/%s%s/%s\"" package_name service_name name;
Code.emit t `None "module Request = %s" input;
Code.emit t `None "module Response = %s" output;
Code.emit t `End "end";
Code.emit t `Begin "let %s = " uncapitalized_name;
Code.emit t `None "(module %s : Runtime'.Service.Message with type t = %s ), "
let sig_t = sprintf "Runtime'.Service.Rpc with type Request.t = %s and type Response.t = %s" input_t output_t in
Code.emit implementation `Begin "module %s : %s = struct" capitalized_name sig_t;
Code.emit implementation `None "let package_name = %s" (Option.value_map ~default:"None" ~f:(fun n -> sprintf "Some \"%s\"" n) package_name_opt);
Code.emit implementation `None "let service_name = \"%s\"" service_name;
Code.emit implementation `None "let method_name = \"%s\"" name;
Code.emit implementation `None "let name = \"/%s%s/%s\"" package_name service_name name;
Code.emit implementation `None "module Request = %s" input;
Code.emit implementation `None "module Response = %s" output;
Code.emit implementation `End "end";
let sig_t' =
sprintf "(module Runtime'.Service.Message with type t = %s) * (module Runtime'.Service.Message with type t = %s)" input_t output_t
in
Code.emit implementation `Begin "let %s : %s = " uncapitalized_name sig_t';
Code.emit implementation `None "(module %s : Runtime'.Service.Message with type t = %s ), "
input
input_t;
Code.emit t `None "(module %s : Runtime'.Service.Message with type t = %s )"
Code.emit implementation `None "(module %s : Runtime'.Service.Message with type t = %s )"
output
output_t;
Code.emit t `End "";
Code.emit implementation `End "";

Code.emit signature `None "module %s : %s" capitalized_name sig_t;
Code.emit signature `None "val %s : %s" uncapitalized_name sig_t';
()
in
let name = Option.value_exn ~message:"Service definitions must have a name" name in
let t = Code.init () in
Code.emit t `Begin "module %s = struct" (Scope.get_name scope name);
let signature = Code.init () in
let implementation = Code.init () in
Code.emit signature `Begin "module %s : sig" (Scope.get_name scope name);
Code.emit implementation `Begin "module %s = struct" (Scope.get_name scope name);
let local_scope = Scope.Local.init () in

List.iter ~f:(emit_method t local_scope (Scope.push scope name) name) methods;
Code.emit t `End "end";
t
List.iter ~f:(emit_method signature implementation local_scope (Scope.push scope name) name) methods;
Code.emit signature `End "end";
Code.emit implementation `End "end";
signature, implementation

let emit_extension ~scope ~params field =
let FieldDescriptorProto.{ name; extendee; _ } = field in
Expand Down Expand Up @@ -239,40 +249,37 @@ let rec emit_message ~params ~syntax scope

let rec wrap_packages ~params ~syntax ~options scope message_type services = function
| [] ->
let {module_name = _; implementation; _} = emit_message ~params ~syntax scope message_type in
let { module_name = _; implementation; signature } = emit_message ~params ~syntax scope message_type in
List.iter ~f:(fun service ->
Code.append implementation (emit_service_type ~options scope service)
) services;
implementation
let signature', implementation' = emit_service_type ~options scope service in
Code.append implementation implementation';
Code.append signature signature';
()
) services;
signature, implementation

| package :: packages ->
let signature = Code.init () in
let implementation = Code.init () in
let package_name = Scope.get_name scope package in
let scope = Scope.push scope package in
Code.emit implementation `Begin "module %s = struct" package_name;
Code.append implementation (wrap_packages ~params ~syntax ~options scope message_type services packages);

let signature', implementation' =
wrap_packages ~params ~syntax ~options scope message_type services packages
in

Code.emit implementation `Begin "module rec %s : sig" package_name;
Code.append implementation signature';
Code.emit implementation `EndBegin "end = struct";
Code.append implementation implementation';
Code.emit implementation `End "end";
implementation
Code.emit signature `Begin "module rec %s : sig" package_name;
Code.append signature signature';
Code.emit signature `End "end";

let parse_proto_file ~params scope
FileDescriptorProto.{ name; package; dependency = dependencies; public_dependency = _;
weak_dependency = _; message_type = message_types;
enum_type = enum_types; service = services; extension;
options; source_code_info = _; syntax; }
=
let name = Option.value_exn ~message:"All files must have a name" name |> String.map ~f:(function '-' -> '_' | c -> c) in
let syntax = match syntax with
| None | Some "proto2" -> `Proto2
| Some "proto3" -> `Proto3
| _ -> failwith "Unsupported syntax"
in
let message_type =
DescriptorProto.{name = None; nested_type=message_types; enum_type = enum_types;
field = []; extension; extension_range = []; oneof_decl = [];
options = None; reserved_range = []; reserved_name = []; }
in
let implementation = Code.init () in
signature, implementation

let emit_header implementation ~name ~syntax ~params =
Code.emit implementation `None "(************************************************)";
Code.emit implementation `None "(* AUTOGENERATED FILE - DO NOT EDIT! *)";
Code.emit implementation `None "(************************************************)";
Expand All @@ -292,6 +299,27 @@ let parse_proto_file ~params scope
Code.emit implementation `None " singleton_record=%b" params.singleton_record;
Code.emit implementation `None "*)";
Code.emit implementation `None "";
()

let parse_proto_file ~params scope
FileDescriptorProto.{ name; package; dependency = dependencies; public_dependency = _;
weak_dependency = _; message_type = message_types;
enum_type = enum_types; service = services; extension;
options; source_code_info = _; syntax; }
=
let name = Option.value_exn ~message:"All files must have a name" name |> String.map ~f:(function '-' -> '_' | c -> c) in
let syntax = match syntax with
| None | Some "proto2" -> `Proto2
| Some "proto3" -> `Proto3
| _ -> failwith "Unsupported syntax"
in
let message_type =
DescriptorProto.{name = None; nested_type=message_types; enum_type = enum_types;
field = []; extension; extension_range = []; oneof_decl = [];
options = None; reserved_range = []; reserved_name = []; }
in
let implementation = Code.init () in
emit_header implementation ~name ~syntax ~params;
Code.emit implementation `None "open Ocaml_protoc_plugin.Runtime [@@warning \"-33\"]";
List.iter ~f:(Code.emit implementation `None "open %s [@@warning \"-33\"]" ) params.opens;
let _ = match dependencies with
Expand All @@ -306,12 +334,11 @@ let parse_proto_file ~params scope
Code.emit implementation `End "end";
Code.emit implementation `None "(**/**)";
in
wrap_packages ~params ~syntax ~options scope message_type services (Option.value_map ~default:[] ~f:(String.split_on_char ~sep:'.') package)
|> Code.append implementation;
let _signature', implementation' =
wrap_packages ~params ~syntax ~options scope message_type services (Option.value_map ~default:[] ~f:(String.split_on_char ~sep:'.') package)
in

Code.append implementation implementation';

let out_name =
Filename.remove_extension name
|> sprintf "%s.ml"
in
out_name, implementation
let base_name = Filename.remove_extension name in
(base_name ^ ".ml"), implementation
135 changes: 108 additions & 27 deletions src/plugin/scope.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
open StdLabels
open MoreLabels

let failwith_f fmt =
Printf.ksprintf (fun s -> failwith s) fmt

let dump_tree = false
let dump_ocaml_names = false

module StringMap = struct
include Map.Make(String)
Expand All @@ -10,11 +14,12 @@ module StringMap = struct
let add_uniq ~key ~data map =
update ~key ~f:(function
| None -> Some data
| Some _ -> failwith (Printf.sprintf "Key %s already exists" key)
| Some _ -> failwith_f "Key %s already exists" key
) map
end
module StringSet = Set.Make(String)


(** Module to avoid name clashes in a local scope *)
module Local = struct
type t = (string, unit) Hashtbl.t
Expand Down Expand Up @@ -301,69 +306,145 @@ end

type t = { module_name: string;
package_depth: int;
proto_path: string;
proto_path: string list;
type_db: element StringMap.t;
ocaml_names: StringSet.t;
}

let dump_type_map type_map =
Printf.eprintf "Type map:\n";
StringMap.iter ~f:(fun ~key ~data:{module_name; ocaml_name; cyclic; _ } ->
Printf.eprintf " %s -> %s#%s, C:%b\n%!" key module_name ocaml_name cyclic
) type_map;
Printf.eprintf "Type map end:\n%!"
Printf.eprintf "Type map end.\n%!"

let init files =
let type_db = Type_tree.create_db files in
let ocaml_names =
StringMap.fold ~init:StringSet.empty
~f:(fun ~key:_ ~data:{ocaml_name; _} acc ->
StringSet.add ocaml_name acc
) type_db
in
if dump_tree then dump_type_map type_db;
{ module_name = ""; proto_path = ""; package_depth = 0; type_db; }
if dump_ocaml_names then
StringSet.iter ~f:(Printf.eprintf "%s\n") ocaml_names;


{ module_name = ""; proto_path = []; package_depth = 0; type_db; ocaml_names}

let for_descriptor t FileDescriptorProto.{ name; package; _ } =
let name = Option.value_exn ~message:"All file descriptors must have a name" name in
let module_name = module_name_of_proto name in
let package_depth = Option.value_map ~default:0 ~f:(fun p -> String.split_on_char ~sep:'.' p |> List.length) package in
{ t with package_depth; module_name; proto_path = "" }
{ t with package_depth; module_name; proto_path = [] }

let push: t -> string -> t = fun t name -> { t with proto_path = t.proto_path ^ "." ^ name }
let get_proto_path t =
"" :: (List.rev t.proto_path) |> String.concat ~sep:"."

let rec drop n = function
| [] -> []
| _ :: xs when n > 0 -> drop (n - 1) xs
| xs -> xs
let push: t -> string -> t = fun t name -> { t with proto_path = name :: t.proto_path }

let get_scoped_name ?postfix t name =
let name = Option.value_exn ~message:"Does not contain a name" name in
(* Take the first n elements from the list *)
let take n l =
let rec inner = function
| (0, _) -> []
| (_, []) -> []
| (n, x :: xs) -> x :: inner (n - 1, xs)
in
inner (n, l)
in

(* Resolve name in the current context and return the fully qualified module name,
iff exists *)
let resolve t name =
let rec lookup name = function
| path ->
begin
let path_str = String.concat ~sep:"." (name :: path |> List.rev) in
match StringSet.mem path_str t.ocaml_names with
| false -> begin
match path with
| [] -> None
| _ :: ps -> lookup name ps
end
| true -> Some path_str
end
in
let { ocaml_name = ocaml_path; _ } =
StringMap.find (get_proto_path t) t.type_db
in
let path = match String.equal "" ocaml_path with
| false -> String.split_on_char ~sep:'.' ocaml_path |> List.rev
| true -> []
in
lookup name path
in

let name = Option.value_exn ~message:"Does not contain a name" name in
let { ocaml_name; module_name; _ } = StringMap.find name t.type_db in
let type_name = match String.equal module_name t.module_name with

(* Lookup a fully qualified name in the current scope.
Returns the shortest name for the type in the current scope *)
let rec lookup postfix_length = function
| p :: ps ->
begin
let expect = String.concat ~sep:"." (List.rev (p :: ps)) in
let resolve_res = resolve t p in
match resolve_res with
| Some path when String.equal path expect ->
let how_many = postfix_length in
let ocaml_name =
String.split_on_char ~sep:'.' ocaml_name
|> List.rev
|> take how_many
|> List.rev
|> String.concat ~sep:"."
in
ocaml_name
| _ ->
lookup (postfix_length + 1) ps
end
| [] ->
failwith_f "Unable to reference '%s'. This is due to a limitation in the Ocaml mappings. To work around this limitation make sure to use a unique package name" name
in
let type_name =
match String.equal module_name t.module_name with
| true ->
ocaml_name
|> String.split_on_char ~sep:'.'
|> drop t.package_depth
|> String.concat ~sep:"."
| false -> Printf.sprintf "%s.%s.%s" import_module_name module_name ocaml_name
let names =
String.split_on_char ~sep:'.' ocaml_name
|> List.rev
in
lookup 1 names
| false ->
Printf.sprintf "%s.%s.%s" import_module_name module_name ocaml_name
in
(* Strip away the package depth *)
Option.value_map ~default:type_name ~f:(fun postfix -> type_name ^ "." ^ postfix) postfix

match postfix, type_name with
| Some postfix, "" -> postfix
| None, "" -> failwith "Empty type cannot be referenced"
| None, type_name -> type_name
| Some postfix, type_name -> Printf.sprintf "%s.%s" type_name postfix

let get_name t name =
let path = t.proto_path ^ "." ^ name in
let path = Printf.sprintf "%s.%s" (get_proto_path t) name in
match StringMap.find_opt path t.type_db with
| Some { ocaml_name; _ } -> String.split_on_char ~sep:'.' ocaml_name |> List.rev |> List.hd
| None -> failwith (Printf.sprintf "Cannot find %s in %s." name t.proto_path)
| None -> failwith_f "Cannot find '%s' in '%s'." name (get_proto_path t)

let get_name_exn t name =
let name = Option.value_exn ~message:"Does not contain a name" name in
get_name t name

let get_current_scope t =
let { module_name; ocaml_name = _; _ } = StringMap.find t.proto_path t.type_db in
(String.lowercase_ascii module_name) ^ t.proto_path
let { module_name; ocaml_name = _; _ } = StringMap.find (get_proto_path t) t.type_db in
(String.lowercase_ascii module_name) ^ (get_proto_path t)

let get_package_name { proto_path; _ } =
match String.split_on_char ~sep:'.' proto_path with
| _ :: xs -> List.rev xs |> List.tl |> List.rev |> String.concat ~sep:"." |> Option.some
let get_package_name t =
match t.proto_path with
| _ :: xs -> List.rev xs |> String.concat ~sep:"." |> Option.some
| _ -> None

let is_cyclic t =
let { cyclic; _ } = StringMap.find t.proto_path t.type_db in
let { cyclic; _ } = StringMap.find (get_proto_path t) t.type_db in
cyclic
3 changes: 3 additions & 0 deletions src/plugin/scope.mli
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ val is_cyclic: t -> bool

(** Test is the options specify name mangling *)
val has_mangle_option: Spec.Descriptor.Google.Protobuf.FileOptions.t option -> bool

(** Get stringified version of the current proto path *)
val get_proto_path: t -> string
Loading

0 comments on commit 6cb91b6

Please sign in to comment.