Skip to content

Commit

Permalink
Merge pull request #3228 from softwaremill/multi-customise-schemas
Browse files Browse the repository at this point in the history
Allow usage-site customisation of referenced schemas
  • Loading branch information
adamw authored Oct 9, 2023
2 parents 2184b0d + 1f4c358 commit 79717e1
Show file tree
Hide file tree
Showing 30 changed files with 404 additions and 178 deletions.
1 change: 1 addition & 0 deletions core/src/main/scala/sttp/tapir/attribute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ object AttributeKey extends AttributeKeyMacros
case class AttributeMap private (private val storage: Map[String, Any]) {
def get[T](k: AttributeKey[T]): Option[T] = storage.get(k.typeName).asInstanceOf[Option[T]]
def put[T](k: AttributeKey[T], v: T): AttributeMap = copy(storage = storage + (k.typeName -> v))
def remove[T](k: AttributeKey[T]): AttributeMap = copy(storage = storage - k.typeName)

def isEmpty: Boolean = storage.isEmpty
def nonEmpty: Boolean = storage.nonEmpty
Expand Down
10 changes: 5 additions & 5 deletions doc/docs/json-schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ You can conveniently generate JSON schema from Tapir schema, which can be derive
Schema generation can now be performed like in the following example:

```scala mdoc:compile-only
import sttp.apispec.{ReferenceOr, Schema => ASchema}
import sttp.apispec.{Schema => ASchema}
import sttp.tapir._
import sttp.tapir.docs.apispec.schema._
import sttp.tapir.generic.auto._
Expand All @@ -21,7 +21,7 @@ import sttp.tapir.generic.auto._
case class Child(childName: String) // to illustrate unique name generation
val tSchema = implicitly[Schema[Parent]]

val jsonSchema: ReferenceOr[ASchema] = TapirSchemaToJsonSchema(
val jsonSchema: ASchema = TapirSchemaToJsonSchema(
tSchema,
markOptionsAsNullable = true,
metaSchema = MetaSchemaDraft04 // default
Expand All @@ -44,7 +44,7 @@ you will get a codec for `sttp.apispec.Schema`:
import io.circe.Printer
import io.circe.syntax._
import sttp.apispec.circe._
import sttp.apispec.{ReferenceOr, Schema => ASchema, SchemaType => ASchemaType}
import sttp.apispec.{Schema => ASchema, SchemaType => ASchemaType}
import sttp.tapir._
import sttp.tapir.docs.apispec.schema._
import sttp.tapir.generic.auto._
Expand All @@ -57,12 +57,12 @@ import sttp.tapir.Schema.annotations.title
case class Child(childName: String)
val tSchema = implicitly[Schema[Parent]]

val jsonSchema: ReferenceOr[ASchema] = TapirSchemaToJsonSchema(
val jsonSchema: ASchema = TapirSchemaToJsonSchema(
tSchema,
markOptionsAsNullable = true)

// JSON serialization
val schemaAsJson = jsonSchema.getOrElse(ASchema(ASchemaType.Null)).asJson
val schemaAsJson = jsonSchema.asJson
val schemaStr: String = Printer.spaces2.print(schemaAsJson.deepDropNullValues)
```

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package sttp.tapir.docs.apispec.schema

import sttp.apispec.{ReferenceOr, SchemaType, Schema => ASchema}
import sttp.apispec.{SchemaType, Schema => ASchema}
import sttp.tapir.{Codec, Schema => TSchema, SchemaType => TSchemaType}

/** Converts a tapir schema to an OpenAPI/AsyncAPI reference (if the schema is named), or to the appropriate schema. */
Expand All @@ -9,16 +9,18 @@ class Schemas(
toSchemaReference: ToSchemaReference,
markOptionsAsNullable: Boolean
) {
def apply[T](codec: Codec[T, _, _]): ReferenceOr[ASchema] = apply(codec.schema)
def apply[T](codec: Codec[T, _, _]): ASchema = apply(codec.schema)

def apply(schema: TSchema[_]): ReferenceOr[ASchema] = {
SchemaKey(schema) match {
case Some(key) => Left(toSchemaReference.map(key))
def apply(schema: TSchema[_]): ASchema = {
schema.name match {
case Some(name) => toSchemaReference.map(schema, name)
case None =>
schema.schemaType match {
case TSchemaType.SArray(nested @ TSchema(_, Some(name), isOptional, _, _, _, _, _, _, _, _)) =>
Right(ASchema(SchemaType.Array).copy(items = Some(Left(toSchemaReference.map(SchemaKey(nested, name))))))
.map(s => if (isOptional && markOptionsAsNullable) s.copy(nullable = Some(true)) else s)
val s = ASchema(SchemaType.Array)
.copy(items = Some(toSchemaReference.map(nested, name)))

if (isOptional && markOptionsAsNullable) s.copy(nullable = Some(true)) else s
case TSchemaType.SOption(ts) => apply(ts)
case _ => tschemaToASchema(schema)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package sttp.tapir.docs.apispec.schema

import sttp.apispec.{Schema => ASchema, _}
import sttp.apispec.{Schema => ASchema}
import sttp.tapir.Schema.SName
import sttp.tapir._
import sttp.tapir.internal.IterableToListMap
Expand All @@ -10,36 +10,41 @@ import scala.collection.immutable.ListMap
class SchemasForEndpoints(
es: Iterable[AnyEndpoint],
schemaName: SName => String,
toKeyedSchemas: ToKeyedSchemas,
markOptionsAsNullable: Boolean,
additionalOutputs: List[EndpointOutput[_]]
) {

def apply(): (ListMap[SchemaId, ReferenceOr[ASchema]], Schemas) = {
val keyedSchemas = ToKeyedSchemas.unique(
/** @return
* A tuple: the first element can be used to create the components section in the docs. The second can be used to resolve (possible)
* top-level references from parameters / bodies.
*/
def apply(): (ListMap[SchemaId, ASchema], Schemas) = {
val keyedCombinedSchemas: Iterable[KeyedSchema] = ToKeyedSchemas.uniqueCombined(
es.flatMap(e =>
forInput(e.securityInput) ++ forInput(e.input) ++ forOutput(e.errorOutput) ++ forOutput(e.output)
) ++ additionalOutputs.flatMap(forOutput(_))
)
val keysToIds = calculateUniqueIds(keyedSchemas.map(_._1), (key: SchemaKey) => schemaName(key.name))
val keysToIds: Map[SchemaKey, SchemaId] = calculateUniqueIds(keyedCombinedSchemas.map(_._1), (key: SchemaKey) => schemaName(key.name))

val toSchemaReference = new ToSchemaReference(keysToIds)
val toSchemaReference = new ToSchemaReference(keysToIds, keyedCombinedSchemas.toMap)
val tschemaToASchema = new TSchemaToASchema(toSchemaReference, markOptionsAsNullable)

val keysToSchemas: ListMap[SchemaKey, ASchema] = keyedCombinedSchemas.map(td => (td._1, tschemaToASchema(td._2))).toListMap
val schemaIds: Map[SchemaKey, (SchemaId, ASchema)] = keysToSchemas.map { case (k, v) => k -> ((keysToIds(k), v)) }

val schemas = new Schemas(tschemaToASchema, toSchemaReference, markOptionsAsNullable)
val keysToSchemas = keyedSchemas.map(td => (td._1, tschemaToASchema(td._2))).toListMap

val schemaIds = keysToSchemas.map { case (k, v) => k -> ((keysToIds(k), v)) }
(schemaIds.values.toListMap, schemas)
}

private def forInput(input: EndpointInput[_]): List[KeyedSchema] = {
input match {
case EndpointInput.FixedMethod(_, _, _) => List.empty
case EndpointInput.FixedPath(_, _, _) => List.empty
case EndpointInput.PathCapture(_, codec, _) => toKeyedSchemas(codec)
case EndpointInput.PathCapture(_, codec, _) => ToKeyedSchemas(codec)
case EndpointInput.PathsCapture(_, _) => List.empty
case EndpointInput.Query(_, _, codec, _) => toKeyedSchemas(codec)
case EndpointInput.Cookie(_, codec, _) => toKeyedSchemas(codec)
case EndpointInput.Query(_, _, codec, _) => ToKeyedSchemas(codec)
case EndpointInput.Cookie(_, codec, _) => ToKeyedSchemas(codec)
case EndpointInput.QueryParams(_, _) => List.empty
case _: EndpointInput.Auth[_, _] => List.empty
case _: EndpointInput.ExtractFromRequest[_] => List.empty
Expand All @@ -48,6 +53,7 @@ class SchemasForEndpoints(
case op: EndpointIO[_] => forIO(op)
}
}

private def forOutput(output: EndpointOutput[_]): List[KeyedSchema] = {
output match {
case EndpointOutput.OneOf(variants, _) => variants.flatMap(variant => forOutput(variant.output)).toList
Expand All @@ -57,19 +63,19 @@ class SchemasForEndpoints(
case EndpointOutput.Void() => List.empty
case EndpointOutput.Pair(left, right, _, _) => forOutput(left) ++ forOutput(right)
case EndpointOutput.WebSocketBodyWrapper(wrapped) =>
toKeyedSchemas(wrapped.codec) ++ toKeyedSchemas(wrapped.requests) ++ toKeyedSchemas(wrapped.responses)
ToKeyedSchemas(wrapped.codec) ++ ToKeyedSchemas(wrapped.requests) ++ ToKeyedSchemas(wrapped.responses)
case op: EndpointIO[_] => forIO(op)
}
}

private def forIO(io: EndpointIO[_]): List[KeyedSchema] = {
io match {
case EndpointIO.Pair(left, right, _, _) => forIO(left) ++ forIO(right)
case EndpointIO.Header(_, codec, _) => toKeyedSchemas(codec)
case EndpointIO.Header(_, codec, _) => ToKeyedSchemas(codec)
case EndpointIO.Headers(_, _) => List.empty
case EndpointIO.Body(_, codec, _) => toKeyedSchemas(codec)
case EndpointIO.Body(_, codec, _) => ToKeyedSchemas(codec)
case EndpointIO.OneOfBody(variants, _) => variants.flatMap(v => forIO(v.bodyAsAtom))
case EndpointIO.StreamBodyWrapper(StreamBodyIO(_, codec, _, _, _)) => toKeyedSchemas(codec.schema)
case EndpointIO.StreamBodyWrapper(StreamBodyIO(_, codec, _, _, _)) => ToKeyedSchemas(codec.schema)
case EndpointIO.MappedPair(wrapped, _) => forIO(wrapped)
case EndpointIO.FixedHeader(_, _, _) => List.empty
case EndpointIO.Empty(_, _) => List.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,63 +4,56 @@ import sttp.apispec.{Schema => ASchema, _}
import sttp.tapir.Schema.Title
import sttp.tapir.Validator.EncodeToRaw
import sttp.tapir.docs.apispec.DocsExtensionAttribute.RichSchema
import sttp.tapir.docs.apispec.schema.TSchemaToASchema.{tDefaultToADefault, tExampleToAExample}
import sttp.tapir.docs.apispec.{DocsExtensions, exampleValue}
import sttp.tapir.internal.{IterableToListMap, _}
import sttp.tapir.internal._
import sttp.tapir.{Validator, Schema => TSchema, SchemaType => TSchemaType}

/** Converts a tapir schema to an OpenAPI/AsyncAPI schema, using the given map to resolve nested references. */
/** Converts a tapir schema to an OpenAPI/AsyncAPI schema, using `toSchemaReference` to resolve nested references. */
private[schema] class TSchemaToASchema(toSchemaReference: ToSchemaReference, markOptionsAsNullable: Boolean) {
def apply[T](schema: TSchema[T], isOptionElement: Boolean = false): ReferenceOr[ASchema] = {
def apply[T](schema: TSchema[T], isOptionElement: Boolean = false): ASchema = {
val nullable = markOptionsAsNullable && isOptionElement
val result = schema.schemaType match {
case TSchemaType.SInteger() => Right(ASchema(SchemaType.Integer))
case TSchemaType.SNumber() => Right(ASchema(SchemaType.Number))
case TSchemaType.SBoolean() => Right(ASchema(SchemaType.Boolean))
case TSchemaType.SString() => Right(ASchema(SchemaType.String))
case TSchemaType.SInteger() => ASchema(SchemaType.Integer)
case TSchemaType.SNumber() => ASchema(SchemaType.Number)
case TSchemaType.SBoolean() => ASchema(SchemaType.Boolean)
case TSchemaType.SString() => ASchema(SchemaType.String)
case p @ TSchemaType.SProduct(fields) =>
Right(
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields)
)
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields)
)
case TSchemaType.SArray(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
Right(ASchema(SchemaType.Array).copy(items = Some(Left(toSchemaReference.map(SchemaKey(nested, name))))))
case TSchemaType.SArray(el) => Right(ASchema(SchemaType.Array).copy(items = Some(apply(el))))
case TSchemaType.SOption(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) =>
Left(toSchemaReference.map(SchemaKey(nested, name)))
case TSchemaType.SOption(el) => apply(el, isOptionElement = true)
case TSchemaType.SBinary() => Right(ASchema(SchemaType.String).copy(format = SchemaFormat.Binary))
case TSchemaType.SDate() => Right(ASchema(SchemaType.String).copy(format = SchemaFormat.Date))
case TSchemaType.SDateTime() => Right(ASchema(SchemaType.String).copy(format = SchemaFormat.DateTime))
case TSchemaType.SRef(fullName) => Left(toSchemaReference.mapDirect(fullName))
ASchema(SchemaType.Array).copy(items = Some(toSchemaReference.map(nested, name)))
case TSchemaType.SArray(el) => ASchema(SchemaType.Array).copy(items = Some(apply(el)))
case TSchemaType.SOption(nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _)) => toSchemaReference.map(nested, name)
case TSchemaType.SOption(el) => apply(el, isOptionElement = true)
case TSchemaType.SBinary() => ASchema(SchemaType.String).copy(format = SchemaFormat.Binary)
case TSchemaType.SDate() => ASchema(SchemaType.String).copy(format = SchemaFormat.Date)
case TSchemaType.SDateTime() => ASchema(SchemaType.String).copy(format = SchemaFormat.DateTime)
case TSchemaType.SRef(fullName) => toSchemaReference.mapDirect(fullName)
case TSchemaType.SCoproduct(schemas, d) =>
Right(
ASchema
.apply(
schemas
.filterNot(_.hidden)
.map {
case nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _) => Left(toSchemaReference.map(SchemaKey(nested, name)))
case t => apply(t)
}
.sortBy {
case Left(Reference(ref, _, _)) => ref
case Right(schema) => schema.`type`.collect { case t: BasicSchemaType => t.value }.getOrElse("") + schema.toString
},
d.map(tDiscriminatorToADiscriminator)
)
ASchema.oneOf(
schemas
.filterNot(_.hidden)
.map {
case nested @ TSchema(_, Some(name), _, _, _, _, _, _, _, _, _) => toSchemaReference.map(nested, name)
case t => apply(t)
}
.sortBy {
case schema if schema.$ref.isDefined => schema.$ref.get
case schema => schema.`type`.collect { case t: BasicSchemaType => t.value }.getOrElse("") + schema.toString
},
d.map(tDiscriminatorToADiscriminator)
)
case p @ TSchemaType.SOpenProduct(fields, valueSchema) =>
Right(
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields),
additionalProperties = Some(SchemaKey(valueSchema) match {
case Some(key) => Left(toSchemaReference.map(key))
case _ => apply(valueSchema)
}).filterNot(_ => valueSchema.hidden)
)
ASchema(SchemaType.Object).copy(
required = p.required.map(_.encodedName),
properties = extractProperties(fields),
additionalProperties = Some(valueSchema.name match {
case Some(name) => toSchemaReference.map(valueSchema, name)
case _ => apply(valueSchema)
}).filterNot(_ => valueSchema.hidden)
)
}

Expand All @@ -70,20 +63,25 @@ private[schema] class TSchemaToASchema(toSchemaReference: ToSchemaReference, mar
case _ => false
}

result
.map(s => if (nullable) s.copy(nullable = Some(true)) else s)
.map(addMetadata(_, schema))
.map(addTitle(_, schema))
.map(addConstraints(_, primitiveValidators, schemaIsWholeNumber))
if (result.$ref.isEmpty) {
// only customising non-reference schemas; references might get enriched with some meta-data if there
// are multiple different customisations of the referenced schema in ToSchemaReference (#1203)
var s = result
s = if (nullable) s.copy(nullable = Some(true)) else s
s = addMetadata(s, schema)
s = addTitle(s, schema)
s = addConstraints(s, primitiveValidators, schemaIsWholeNumber)
s
} else result
}

private def extractProperties[T](fields: List[TSchemaType.SProductField[T]]) = {
fields
.filterNot(_.schema.hidden)
.map { f =>
SchemaKey(f.schema) match {
case Some(key) => f.name.encodedName -> Left(toSchemaReference.map(key))
case None => f.name.encodedName -> apply(f.schema)
f.schema.name match {
case Some(name) => f.name.encodedName -> toSchemaReference.map(f.schema, name)
case None => f.name.encodedName -> apply(f.schema)
}
}
.toListMap
Expand All @@ -95,8 +93,8 @@ private[schema] class TSchemaToASchema(toSchemaReference: ToSchemaReference, mar
private def addMetadata(oschema: ASchema, tschema: TSchema[_]): ASchema = {
oschema.copy(
description = tschema.description.orElse(oschema.description),
default = tschema.default.flatMap { case (_, raw) => raw.flatMap(r => exampleValue(tschema, r)) }.orElse(oschema.default),
example = tschema.encodedExample.flatMap(exampleValue(tschema, _)).orElse(oschema.example),
default = tDefaultToADefault(tschema).orElse(oschema.default),
example = tExampleToAExample(tschema).orElse(oschema.example),
format = tschema.format.orElse(oschema.format),
deprecated = (if (tschema.deprecated) Some(true) else None).orElse(oschema.deprecated),
extensions = DocsExtensions.fromIterable(tschema.docsExtensions)
Expand Down Expand Up @@ -155,8 +153,8 @@ private[schema] class TSchemaToASchema(toSchemaReference: ToSchemaReference, mar
private def tDiscriminatorToADiscriminator(discriminator: TSchemaType.SDiscriminator): Discriminator = {
val schemas = Some(
discriminator.mapping
.map { case (k, TSchemaType.SRef(fullName)) =>
k -> toSchemaReference.mapDiscriminator(fullName).$ref
.flatMap { case (k, TSchemaType.SRef(fullName)) =>
toSchemaReference.mapDiscriminator(fullName).$ref.map(k -> _)
}
.toList
.sortBy(_._1)
Expand All @@ -165,3 +163,10 @@ private[schema] class TSchemaToASchema(toSchemaReference: ToSchemaReference, mar
Discriminator(discriminator.name.encodedName, schemas)
}
}

object TSchemaToASchema {
def tDefaultToADefault(schema: TSchema[_]): Option[ExampleValue] = schema.default.flatMap { case (_, raw) =>
raw.flatMap(r => exampleValue(schema, r))
}
def tExampleToAExample(schema: TSchema[_]): Option[ExampleValue] = schema.encodedExample.flatMap(exampleValue(schema, _))
}
Loading

0 comments on commit 79717e1

Please sign in to comment.