// Source blob: 185c79e5f67a7ca432e305c8fef23d97068c763c
// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use std::{
cmp::min,
collections::{hash_map::Iter, HashMap, VecDeque},
};
use serde::{Deserialize, Serialize};
/// A multiset of word n-grams: each distinct gram is stored once together
/// with the number of times it has been recorded.
#[derive(PartialEq, Clone, Debug, Serialize, Deserialize)]
pub struct NgramSet {
/// Occurrence count for each gram, keyed by the gram's text (its words
/// joined with single spaces).
map: HashMap<String, u32>,
// once Rust supports it, it'd be nice to make this
// a type parameter & specialize
/// Number of words that make up one gram.
n: u8,
/// Total number of grams recorded so far, counting repeats; this is the
/// sum of all counts in `map`, not the number of distinct keys.
size: usize,
}
impl NgramSet {
    /// Creates an empty set that will collect grams of `n` words each.
    pub fn new(n: u8) -> NgramSet {
        NgramSet {
            map: HashMap::new(),
            n,
            size: 0,
        }
    }

    /// Builds a set of `n`-word grams from the text `s`.
    ///
    /// Equivalent to creating an empty set with `new(n)` and then calling
    /// `analyze(s)` on it.
    pub fn from_str(s: &str, n: u8) -> NgramSet {
        let mut set = NgramSet::new(n);
        set.analyze(s);
        set
    }

    /// Records every run of `n` consecutive words of `s` as a gram.
    ///
    /// Words are delimited by single space characters (`' '`) only: other
    /// whitespace is not a separator, consecutive spaces produce empty words,
    /// and an empty `s` is treated as one empty word. A text with fewer than
    /// `n` words adds nothing, and nothing is ever added when `n` is zero.
    ///
    /// Calling this repeatedly accumulates grams; windows never span two
    /// separate calls.
    pub fn analyze(&mut self, s: &str) {
        let width = usize::from(self.n);
        if width == 0 {
            // A zero-word gram can never be completed, so there is no point
            // in buffering the words at all.
            return;
        }

        // Sliding window over the most recent `width` words.
        let mut window: VecDeque<&str> = VecDeque::with_capacity(width);
        for word in s.split(' ') {
            window.push_back(word);
            if window.len() == width {
                // Join straight out of the window instead of copying it into
                // a temporary Vec first.
                let mut gram = String::new();
                for (position, part) in window.iter().enumerate() {
                    if position > 0 {
                        gram.push(' ');
                    }
                    gram.push_str(part);
                }
                self.add_gram(gram);
                window.pop_front();
            }
        }
    }

    /// Increments the count for `gram` (inserting it if new) and the running
    /// total of recorded grams.
    fn add_gram(&mut self, gram: String) {
        *self.map.entry(gram).or_insert(0) += 1;
        self.size += 1;
    }

    /// Returns how many times `gram` has been recorded, or `0` if it never was.
    pub fn get(&self, gram: &str) -> u32 {
        self.map.get(gram).copied().unwrap_or(0)
    }

    /// Returns the total number of grams recorded, counting repeats.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns `true` if no gram has been recorded.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Computes the Dice coefficient between this set and `other`.
    ///
    /// The score is `2 * shared / (self.len() + other.len())`, where `shared`
    /// sums, over every gram, the smaller of the two sets' counts for it.
    /// Identical non-empty sets score `1.0`.
    ///
    /// Returns `0.0` when the two sets use different gram widths or when
    /// either set is empty.
    pub fn dice(&self, other: &NgramSet) -> f32 {
        // Grams of different widths can never be equal.
        if other.n != self.n {
            return 0.0;
        }

        // Guarding here also avoids 0/0 (NaN) when both sets are empty.
        if self.is_empty() || other.is_empty() {
            return 0.0;
        }

        // Walk the set with fewer recorded grams and probe the other one.
        let (smaller, larger) = if self.len() < other.len() {
            (self, other)
        } else {
            (other, self)
        };

        let matches: u32 = smaller
            .map
            .iter()
            .map(|(gram, &count)| min(count, larger.get(gram)))
            .sum();

        (2.0 * matches as f32) / ((self.len() + other.len()) as f32)
    }
}
impl<'a> IntoIterator for &'a NgramSet {
type Item = (&'a String, &'a u32);
type IntoIter = Iter<'a, String, u32>;
fn into_iter(self) -> Self::IntoIter {
self.map.iter()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a new set starts empty and remembers its gram width.
    #[test]
    fn can_construct() {
        let fresh = NgramSet::new(2);
        assert_eq!(fresh.size, 0);
        assert_eq!(fresh.n, 2);
    }

    /// Scoring two empty sets must not divide zero by zero.
    #[test]
    fn no_nan() {
        let left = NgramSet::from_str("", 2);
        let right = NgramSet::from_str("", 2);
        assert!(!left.dice(&right).is_nan());
    }

    /// Sets built with different gram widths always score zero.
    #[test]
    fn same_size() {
        let bigrams = NgramSet::from_str("", 2);
        let trigrams = NgramSet::from_str("", 3);
        assert_eq!(0f32, bigrams.dice(&trigrams));
    }

    /// Two sets built from the same text are a perfect match.
    #[test]
    fn identical() {
        let text = "one two three apple banana";
        let left = NgramSet::from_str(text, 2);
        let right = NgramSet::from_str(text, 2);
        assert_eq!(1f32, left.dice(&right));
    }
}