forked from intel/openvino-rs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path util.rs
81 lines (70 loc) · 2.17 KB
/
util.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use core::cmp::Ordering;
use float_cmp::{ApproxEq, F32Margin};
/// A structure for holding the `(category, probability)` pair extracted from the output tensor of
/// the OpenVINO classification.
///
/// It is a small plain-data pair, so it derives `Clone` and `Copy` as well as `Debug`.
#[derive(Debug, Clone, Copy)]
pub struct Prediction {
    /// The class (category) index.
    id: usize,
    /// The probability associated with that class.
    prob: f32,
}
impl Prediction {
    /// Creates a prediction for class `id` with probability `prob`.
    pub fn new(id: usize, prob: f32) -> Self {
        Self { id, prob }
    }

    /// Reduce the boilerplate to assert that two predictions are approximately the same.
    ///
    /// The class IDs must match exactly, and the probabilities must be approximately equal
    /// under [`DEFAULT_MARGIN`]. `expected` may be anything convertible into a [`Prediction`],
    /// for example a `(usize, f32)` tuple.
    ///
    /// # Panics
    ///
    /// Panics if the class IDs differ or if the probabilities fall outside the default margin.
    /// `#[track_caller]` makes the panic report the caller's location (the failing test) rather
    /// than a line inside this helper.
    #[track_caller]
    pub fn assert_approx_eq<P: Into<Self>>(&self, expected: P) {
        let expected = expected.into();
        assert_eq!(
            self.id, expected.id,
            "Expected class ID {} but found {}",
            expected.id, self.id
        );
        let approx_matches = self.approx_eq(&expected, DEFAULT_MARGIN);
        assert!(
            approx_matches,
            "Expected probability {} but found {} (outside of default margin of error)",
            expected.prob, self.prob
        );
    }
}
impl From<(usize, f32)> for Prediction {
fn from(p: (usize, f32)) -> Self {
Prediction::new(p.0, p.1)
}
}
/// Classification results sort from most to least probable. A prediction with a higher
/// probability compares as *less than* one with a lower probability, so `sort()` places the top
/// result first. The class ID does not take part in the comparison.
///
/// # Panics
///
/// `cmp` panics if either probability is NaN.
impl Ord for Prediction {
    fn cmp(&self, other: &Self) -> Ordering {
        assert!(!self.prob.is_nan());
        assert!(!other.prob.is_nan());
        // The operands are deliberately reversed (`other` vs. `self`) to get descending order.
        let descending = other.prob.partial_cmp(&self.prob);
        descending.expect("a comparable value")
    }
}
/// Partial ordering that simply defers to the total [`Ord`] implementation, so the two always
/// agree.
impl PartialOrd for Prediction {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Predictions compare equal when their probabilities are equal; the class ID is ignored.
/// This mirrors the [`Ord`] implementation, which also looks only at the probability.
impl PartialEq for Prediction {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.prob, other.prob);
        lhs == rhs
    }
}
/// NOTE(review): `Eq` promises reflexivity, which a NaN probability would break (`NaN != NaN`).
/// The [`Ord`] implementation panics on NaN, so NaN probabilities are presumably never expected
/// here.
impl Eq for Prediction {}
/// Approximate equality for borrowed predictions. Only the probabilities are compared, using
/// `float_cmp`'s `f32` comparison with the supplied [`F32Margin`]; the class ID is not consulted.
impl ApproxEq for &Prediction {
    type Margin = F32Margin;

    fn approx_eq<T: Into<Self::Margin>>(self, other: Self, margin: T) -> bool {
        let tolerance: F32Margin = margin.into();
        let (lhs, rhs) = (self.prob, other.prob);
        lhs.approx_eq(rhs, tolerance)
    }
}
/// Tolerance applied by `Prediction::assert_approx_eq` when comparing probabilities: an epsilon
/// of `0.01` and 2 ULPs. See `float_cmp::F32Margin` for how the two limits are combined.
pub const DEFAULT_MARGIN: F32Margin = F32Margin {
    ulps: 2,
    epsilon: 0.01,
};
/// Shorthand for a list of [`Prediction`]s. Sorting it orders the entries from most to least
/// probable.
pub type Predictions = std::vec::Vec<Prediction>;