Skip to content

Commit

Permalink
server : add TEI API format for /rerank endpoint (#11942)
Browse files Browse the repository at this point in the history
* server : add TEI API format for /rerank endpoint

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <[email protected]>

* fix

* also gitignore examples/server/*.gz.hpp

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
ngxson and ggerganov authored Feb 18, 2025
1 parent 5137da7 commit 63ac128
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ examples/server/*.css.hpp
examples/server/*.html.hpp
examples/server/*.js.hpp
examples/server/*.mjs.hpp
examples/server/*.gz.hpp
!build_64.sh
!examples/*.bat
!examples/*/*.kts
Expand Down
15 changes: 13 additions & 2 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4263,6 +4263,11 @@ int main(int argc, char ** argv) {
// return;
//}

// if true, use TEI API format, otherwise use Jina API format
// Jina: https://jina.ai/reranker/
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
bool is_tei_format = body.contains("texts");

json query;
if (body.count("query") == 1) {
query = body.at("query");
Expand All @@ -4275,7 +4280,8 @@ int main(int argc, char ** argv) {
return;
}

std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
std::vector<std::string> documents = json_value(body, "documents",
json_value(body, "texts", std::vector<std::string>()));
if (documents.empty()) {
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
return;
Expand Down Expand Up @@ -4320,7 +4326,12 @@ int main(int argc, char ** argv) {
}

// write JSON response
json root = format_response_rerank(body, responses);
json root = format_response_rerank(
body,
responses,
is_tei_format,
documents);

res_ok(res, root);
};

Expand Down
38 changes: 32 additions & 6 deletions examples/server/tests/unit/test_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@ def create_server():
server = ServerPreset.jina_reranker_tiny()


TEST_DOCUMENTS = [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]


def test_rerank():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
"documents": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body["results"]) == 4
Expand All @@ -38,6 +41,29 @@ def test_rerank():
assert least_relevant["index"] == 3


def test_rerank_tei_format():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"texts": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body) == 4

most_relevant = res.body[0]
least_relevant = res.body[0]
for doc in res.body:
if doc["score"] > most_relevant["score"]:
most_relevant = doc
if doc["score"] < least_relevant["score"]:
least_relevant = doc

assert most_relevant["score"] > least_relevant["score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3


@pytest.mark.parametrize("documents", [
[],
None,
Expand Down
62 changes: 42 additions & 20 deletions examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,28 +737,50 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
return res;
}

static json format_response_rerank(const json & request, const json & ranks) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & rank : ranks) {
data.push_back(json{
{"index", i++},
{"relevance_score", json_value(rank, "score", 0.0)},
});
static json format_response_rerank(
const json & request,
const json & ranks,
bool is_tei_format,
std::vector<std::string> & texts) {
json res;
if (is_tei_format) {
// TEI response format
res = json::array();
bool return_text = json_value(request, "return_text", false);
for (const auto & rank : ranks) {
int index = json_value(rank, "index", 0);
json elem = json{
{"index", index},
{"score", json_value(rank, "score", 0.0)},
};
if (return_text) {
elem["text"] = std::move(texts[index]);
}
res.push_back(elem);
}
} else {
// Jina response format
json results = json::array();
int32_t n_tokens = 0;
for (const auto & rank : ranks) {
results.push_back(json{
{"index", json_value(rank, "index", 0)},
{"relevance_score", json_value(rank, "score", 0.0)},
});

n_tokens += json_value(rank, "tokens_evaluated", 0);
}
n_tokens += json_value(rank, "tokens_evaluated", 0);
}

json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json {
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", data}
};
res = json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json{
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", results}
};
}

return res;
}
Expand Down

0 comments on commit 63ac128

Please sign in to comment.