Adding `Sequence` for `PostProcessor`. by Narsil · Pull Request #1052 · huggingface/tokenizers · GitHub

Adding Sequence for PostProcessor. #1052


Merged
4 commits merged on Aug 25, 2022
8 changes: 8 additions & 0 deletions bindings/node/lib/bindings/post-processors.d.ts
@@ -54,3 +54,11 @@ export function templateProcessing(
pair?: string,
specialTokens?: [string, number][]
): PostProcessor;

/**
* Instantiate a new SequenceProcessing.
*
 * @param processors The list of PostProcessors to use
* @since 0.13.0
*/
export function sequenceProcessing(processors: PostProcessor[]): PostProcessor;
1 change: 1 addition & 0 deletions bindings/node/lib/bindings/post-processors.js
@@ -5,4 +5,5 @@ module.exports = {
byteLevelProcessing: native.processors_ByteLevel,
robertaProcessing: native.processors_RobertaProcessing,
templateProcessing: native.processors_TemplateProcessing,
sequenceProcessing: native.processors_Sequence,
};
12 changes: 12 additions & 0 deletions bindings/node/lib/bindings/post-processors.test.ts
@@ -4,6 +4,7 @@ import {
bertProcessing,
byteLevelProcessing,
robertaProcessing,
sequenceProcessing,
templateProcessing,
} from "./post-processors";

@@ -81,3 +82,14 @@ describe("templateProcessing", () => {
expect(processor.constructor.name).toEqual("Processor");
});
});

describe("sequenceProcessing", () => {
it("accepts `PostProcessor[]` as first parameter", () => {
const template = templateProcessing("[CLS] $A [SEP]", "[CLS] $A [SEP] $B:1 [SEP]:1", [
["[CLS]", 1],
["[SEP]", 2],
]);
const bytelevel = byteLevelProcessing(true);
expect(sequenceProcessing([bytelevel, template])).toBeDefined();
});
});
28 changes: 28 additions & 0 deletions bindings/node/native/src/processors.rs
@@ -129,6 +129,33 @@ fn template_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
Ok(js_processor)
}

/// sequence(processors: List[Processor])
fn sequence(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
let processors = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
let mut sequence = Vec::with_capacity(processors.len());

processors.into_iter().try_for_each(|processor| {
match processor.downcast::<JsPostProcessor>().or_throw(&mut cx) {
Ok(processor) => {
let guard = cx.lock();
if let Some(processor_arc) = &processor.borrow(&guard).processor {
let processor: PostProcessorWrapper = (**processor_arc).clone();
sequence.push(processor);
}
Ok(())
}
Err(e) => Err(e),
}
})?;

let mut js_processor = JsPostProcessor::new::<_, JsPostProcessor, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_processor.borrow_mut(&guard).processor = Some(Arc::new(PostProcessorWrapper::Sequence(
tk::processors::sequence::Sequence::new(sequence),
)));
Ok(js_processor)
}

/// Register everything here
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_BertProcessing", prefix), bert_processing)?;
@@ -138,5 +165,6 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
&format!("{}_TemplateProcessing", prefix),
template_processing,
)?;
m.export_function(&format!("{}_Sequence", prefix), sequence)?;
Ok(())
}
1 change: 1 addition & 0 deletions bindings/python/py_src/tokenizers/processors/__init__.py
@@ -5,4 +5,5 @@
BertProcessing = processors.BertProcessing
ByteLevel = processors.ByteLevel
RobertaProcessing = processors.RobertaProcessing
Sequence = processors.Sequence
TemplateProcessing = processors.TemplateProcessing
42 changes: 42 additions & 0 deletions bindings/python/py_src/tokenizers/processors/__init__.pyi
@@ -193,6 +193,48 @@ class RobertaProcessing(PostProcessor):
"""
pass

class Sequence(PostProcessor):
"""
Sequence Processor

Args:
processors (:obj:`List[PostProcessor]`)
The processors that need to be chained
"""

def __init__(self, processors):
pass
def num_special_tokens_to_add(self, is_pair):
"""
Return the number of special tokens that would be added for single/pair sentences.

Args:
is_pair (:obj:`bool`):
Whether the input would be a pair of sequences

Returns:
:obj:`int`: The number of tokens to add
"""
pass
def process(self, encoding, pair=None, add_special_tokens=True):
"""
Post-process the given encodings, generating the final one

Args:
encoding (:class:`~tokenizers.Encoding`):
The encoding for the first sequence

pair (:class:`~tokenizers.Encoding`, `optional`):
The encoding for the pair sequence

add_special_tokens (:obj:`bool`):
Whether to add the special tokens

Return:
:class:`~tokenizers.Encoding`: The final encoding
"""
pass

class TemplateProcessing(PostProcessor):
"""
Provides a way to specify templates in order to add the special tokens to each
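Based on the Sequence stub documented above, here is a minimal usage sketch through the Python bindings; the tokenizer setup is illustrative and mirrors the test added further down in this PR, so the exact ids and offsets are not the point, only the chaining is:

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import ByteLevel, Sequence, TemplateProcessing

# Illustrative tokenizer; any Tokenizer instance works the same way.
tokenizer = Tokenizer(BPE())
tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
tokenizer.add_tokens(["my", "name", "is", "Ġjohn"])

# Chain ByteLevel offset trimming with a classic template, applied one after the other.
tokenizer.post_processor = Sequence(
    [
        ByteLevel(trim_offsets=True),
        TemplateProcessing(
            single=["[CLS]", "$0", "[SEP]"],
            pair=["[CLS]:0", "$A", "[SEP]:0", "$B:1", "[SEP]:1"],
            special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
        ),
    ]
)

encoding = tokenizer.encode("my name is Ġjohn")
print(encoding.ids, encoding.offsets)  # special tokens added, offsets trimmed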
1 change: 1 addition & 0 deletions bindings/python/src/lib.rs
@@ -104,6 +104,7 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<processors::PyRobertaProcessing>()?;
m.add_class::<processors::PyByteLevel>()?;
m.add_class::<processors::PyTemplateProcessing>()?;
m.add_class::<processors::PySequence>()?;
Ok(())
}

33 changes: 33 additions & 0 deletions bindings/python/src/processors.rs
@@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
use tk::processors::bert::BertProcessing;
use tk::processors::byte_level::ByteLevel;
use tk::processors::roberta::RobertaProcessing;
use tk::processors::sequence::Sequence;
use tk::processors::template::{SpecialToken, Template};
use tk::processors::PostProcessorWrapper;
use tk::{Encoding, PostProcessor};
@@ -50,6 +51,7 @@ impl PyPostProcessor {
PostProcessorWrapper::Template(_) => {
Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
})
}
}
@@ -414,6 +416,37 @@ impl PyTemplateProcessing {
}
}

/// Sequence Processor
///
/// Args:
/// processors (:obj:`List[PostProcessor]`)
/// The processors that need to be chained
#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "Sequence")]
#[pyo3(text_signature = "(self, processors)")]
pub struct PySequence {}
#[pymethods]
impl PySequence {
#[new]
#[args(processors)]
fn new(processors_py: &PyList) -> (Self, PyPostProcessor) {
let mut processors: Vec<PostProcessorWrapper> = Vec::with_capacity(processors_py.len());
for n in processors_py.iter() {
let processor: PyRef<PyPostProcessor> = n.extract().unwrap();
let processor = processor.processor.as_ref();
processors.push(processor.clone());
}
let sequence_processor = Sequence::new(processors);
(
PySequence {},
PyPostProcessor::new(Arc::new(PostProcessorWrapper::Sequence(sequence_processor))),
)
}

fn __getnewargs__<'p>(&self, py: Python<'p>) -> &'p PyTuple {
PyTuple::new(py, &[PyList::empty(py)])
}
}

#[cfg(test)]
mod test {
use std::sync::Arc;
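The empty `__getnewargs__` above supplies the constructor arguments that pickle uses when re-creating the object; a minimal round-trip sketch, mirroring the pickling check in the Python test below:

import pickle

from tokenizers.processors import Sequence

# Serialize and restore an (empty) Sequence post-processor.
restored = pickle.loads(pickle.dumps(Sequence([])))
assert isinstance(restored, Sequence)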
47 changes: 47 additions & 0 deletions bindings/python/tests/bindings/test_processors.py
@@ -13,6 +13,7 @@
RobertaProcessing,
ByteLevel,
TemplateProcessing,
Sequence,
)


@@ -179,3 +180,49 @@ def test_roberta_parity(self):
tokenizer.post_processor = self.get_roberta()
template = tokenizer.encode("my name is john", "pair")
assert original.ids == template.ids


class TestSequenceProcessing:
def test_sequence_processing(self):
assert Sequence([]) is not None
assert Sequence([ByteLevel()]) is not None
assert isinstance(Sequence([]), PostProcessor)
assert isinstance(Sequence([]), Sequence)
serialized = pickle.dumps(Sequence([]))
assert isinstance(pickle.loads(serialized), Sequence)

def test_post_process(self):
byte_level = ByteLevel(trim_offsets=True)
template = TemplateProcessing(
single=["[CLS]", "$0", "[SEP]"],
pair=["[CLS]:0", "$A", "[SEP]:0", "$B:1", "[SEP]:1"],
special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
)

tokenizer = Tokenizer(BPE())
tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
tokenizer.add_tokens(["my", "name", "is", "Ġjohn", "pair"])
tokenizer.post_processor = template

# Before the sequence
original = tokenizer.encode("my name is Ġjohn")
assert original.ids == [1, 2, 3, 4, 5, 0]
assert original.type_ids == [0, 0, 0, 0, 0, 0]
assert original.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (11, 16), (0, 0)]
pair = tokenizer.encode("my name is Ġjohn", "pair")
# assert pair.ids == [1, 2, 3, 4, 5, 0, 6, 0]
assert pair.type_ids == [0, 0, 0, 0, 0, 0, 1, 1]
assert pair.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (11, 16), (0, 0), (0, 4), (0, 0)]

processor = Sequence([byte_level, template])
tokenizer.post_processor = processor

original = tokenizer.encode("my name is Ġjohn")
assert original.ids == [1, 2, 3, 4, 5, 0]
assert original.type_ids == [0, 0, 0, 0, 0, 0]
# Offsets ARE trimmed
assert original.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (12, 16), (0, 0)]
pair = tokenizer.encode("my name is Ġjohn", "pair")
# assert pair.ids == [1, 2, 3, 4, 5, 0, 6, 0]
assert pair.type_ids == [0, 0, 0, 0, 0, 0, 1, 1]
assert pair.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (12, 16), (0, 0), (0, 4), (0, 0)]
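
The tests above always attach the chained processor via `tokenizer.post_processor`; the `process` method documented in the Python stub earlier in this PR can also be called on an encoding directly. A minimal sketch using the same illustrative tokenizer as the tests (outputs are printed rather than asserted):

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import ByteLevel, Sequence, TemplateProcessing

tokenizer = Tokenizer(BPE())
tokenizer.add_special_tokens(["[SEP]", "[CLS]"])
tokenizer.add_tokens(["my", "name", "is", "Ġjohn"])

processor = Sequence(
    [
        ByteLevel(trim_offsets=True),
        TemplateProcessing(
            single=["[CLS]", "$0", "[SEP]"],
            pair=["[CLS]:0", "$A", "[SEP]:0", "$B:1", "[SEP]:1"],
            special_tokens=[("[CLS]", 1), ("[SEP]", 0)],
        ),
    ]
)

# No post-processor is attached here, so encode() returns the bare encoding;
# the chained processor is then applied explicitly.
raw = tokenizer.encode("my name is Ġjohn")
final = processor.process(raw)
print(final.ids, final.type_ids, final.offsets)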
33 changes: 26 additions & 7 deletions tokenizers/src/pre_tokenizers/byte_level.rs
@@ -177,7 +177,7 @@ impl PostProcessor for ByteLevel {
fn process_encodings(
&self,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
_add_special_tokens: bool,
) -> Result<Vec<Encoding>> {
if self.trim_offsets {
for encoding in encodings.iter_mut() {
@@ -188,7 +188,11 @@ impl PostProcessor for ByteLevel {
.for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
}
}
<dyn PostProcessor>::default_process(encodings, add_special_tokens)
for (i, encoding) in encodings.iter_mut().enumerate() {
encoding.set_sequence_id(i);
}
Ok(encodings)
//<dyn PostProcessor>::default_process(encodings, add_special_tokens)
}
}

@@ -493,7 +497,7 @@ mod tests {
vec![],
vec![],
vec![],
HashMap::new(),
HashMap::from_iter(vec![(0, 0..5)]),
);

let bytelevel = ByteLevel::default().trim_offsets(true);
@@ -502,24 +506,39 @@
bytelevel.process(start.clone(), None, false).unwrap()
);

let mut pair_expected = Encoding::new(
vec![0; 5],
let pair_expected = Encoding::new(
vec![0; 10],
vec![],
vec![
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
],
vec![],
vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
vec![
(0, 0),
(4, 9),
(13, 18),
(18, 23),
(29, 29),
(0, 0),
(4, 9),
(13, 18),
(18, 23),
(29, 29),
],
vec![],
vec![],
vec![],
HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
);
pair_expected.merge_with(expected, false);
assert_eq!(
pair_expected,
bytelevel
6 changes: 6 additions & 0 deletions tokenizers/src/processors/mod.rs
@@ -1,5 +1,6 @@
pub mod bert;
pub mod roberta;
pub mod sequence;
pub mod template;

// Re-export these as processors
@@ -10,6 +11,7 @@ use serde::{Deserialize, Serialize};
use crate::pre_tokenizers::byte_level::ByteLevel;
use crate::processors::bert::BertProcessing;
use crate::processors::roberta::RobertaProcessing;
use crate::processors::sequence::Sequence;
use crate::processors::template::TemplateProcessing;
use crate::{Encoding, PostProcessor, Result};

@@ -21,6 +23,7 @@ pub enum PostProcessorWrapper {
Bert(BertProcessing),
ByteLevel(ByteLevel),
Template(TemplateProcessing),
Sequence(Sequence),
}

impl PostProcessor for PostProcessorWrapper {
@@ -30,6 +33,7 @@ impl PostProcessor for PostProcessorWrapper {
Self::ByteLevel(bl) => bl.added_tokens(is_pair),
Self::Roberta(roberta) => roberta.added_tokens(is_pair),
Self::Template(template) => template.added_tokens(is_pair),
Self::Sequence(seq) => seq.added_tokens(is_pair),
}
}

@@ -43,6 +47,7 @@
Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
Self::Sequence(seq) => seq.process_encodings(encodings, add_special_tokens),
}
}
}
@@ -51,6 +56,7 @@ impl_enum_from!(BertProcessing, PostProcessorWrapper, Bert);
impl_enum_from!(ByteLevel, PostProcessorWrapper, ByteLevel);
impl_enum_from!(RobertaProcessing, PostProcessorWrapper, Roberta);
impl_enum_from!(TemplateProcessing, PostProcessorWrapper, Template);
impl_enum_from!(Sequence, PostProcessorWrapper, Sequence);

#[cfg(test)]
mod tests {