Source code for nltk.chunk.named_entity

# Natural Language Toolkit: Chunk parsing API
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Edward Loper <edloper@gmail.com>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

"""
Named entity chunker
"""

import os
import pickle
import re
from xml.etree import ElementTree as ET

from nltk.tag import ClassifierBasedTagger, pos_tag

try:
    from nltk.classify import MaxentClassifier
except ImportError:
    pass

from nltk.chunk.api import ChunkParserI
from nltk.chunk.util import ChunkScore
from nltk.data import find
from nltk.tokenize import word_tokenize
from nltk.tree import Tree


class NEChunkParserTagger(ClassifierBasedTagger):
    """
    The IOB tagger used by the chunk parser.

    Wires a maximum-entropy classifier (see ``_classifier_builder``) into
    ``ClassifierBasedTagger`` and supplies the per-token feature extractor
    (``_feature_detector``) that the classifier is trained and run on.
    """

    def __init__(self, train):
        """
        Initialise the base tagger with ``train`` as training data and
        ``_classifier_builder`` as the classifier factory.

        :param train: training data, passed through unchanged to
            ``ClassifierBasedTagger.__init__``
        """
        ClassifierBasedTagger.__init__(
            self, train=train, classifier_builder=self._classifier_builder
        )

    def _classifier_builder(self, train):
        """
        Train and return a ``MaxentClassifier`` on ``train`` using the
        "megam" algorithm.
        """
        # NOTE(review): MaxentClassifier is imported inside a
        # try/except ImportError at module level, so if that import failed
        # this line raises NameError -- confirm that is acceptable.
        return MaxentClassifier.train(
            train, algorithm="megam", gaussian_prior_sigma=1, trace=2
        )

    def _english_wordlist(self):
        """
        Return the "en-basic" word list from ``nltk.corpus.words`` as a set.

        The set is built on first use and cached on ``self._en_wordlist``.
        """
        try:
            wl = self._en_wordlist
        except AttributeError:
            # First call: the corpus import is deferred until it is needed.
            from nltk.corpus import words

            self._en_wordlist = set(words.words("en-basic"))
            wl = self._en_wordlist
        return wl

    def _feature_detector(self, tokens, index, history):
        """
        Build the feature dictionary for the token at ``tokens[index]``.

        :param tokens: sequence of ``(word, pos)`` pairs
        :param index: position of the token to featurise
        :param history: indexable by token position; read here as the tags
            of the preceding tokens (presumably those already assigned by
            the tagger -- confirm against ``ClassifierBasedTagger``)
        :return: dict mapping feature names to values
        """
        word = tokens[index][0]
        pos = simplify_pos(tokens[index][1])
        # Left context.  Values that do not exist near the start of the
        # sentence are set to None.
        if index == 0:
            prevword = prevprevword = None
            prevpos = prevprevpos = None
            prevshape = prevtag = prevprevtag = None
        elif index == 1:
            prevword = tokens[index - 1][0].lower()
            prevprevword = None
            prevpos = simplify_pos(tokens[index - 1][1])
            prevprevpos = None
            # NOTE(review): this takes element [0] of the previous tag,
            # whereas the general branch below uses the whole tag.  Looks
            # unintentional, but changing it alters the features seen by
            # already-trained models -- confirm before touching.
            prevtag = history[index - 1][0]
            # NOTE(review): prevshape is None here although a previous word
            # exists; the general branch computes it.
            prevshape = prevprevtag = None
        else:
            prevword = tokens[index - 1][0].lower()
            prevprevword = tokens[index - 2][0].lower()
            prevpos = simplify_pos(tokens[index - 1][1])
            prevprevpos = simplify_pos(tokens[index - 2][1])
            prevtag = history[index - 1]
            prevprevtag = history[index - 2]
            # prevword was lowercased above, so the shape is computed on the
            # lowercased form, not on the original casing.
            prevshape = shape(prevword)
        # Right context.  Unlike the left context, the POS tags here are
        # lowercased rather than passed through simplify_pos.
        if index == len(tokens) - 1:
            nextword = nextnextword = None
            nextpos = nextnextpos = None
        elif index == len(tokens) - 2:
            nextword = tokens[index + 1][0].lower()
            nextpos = tokens[index + 1][1].lower()
            nextnextword = None
            nextnextpos = None
        else:
            nextword = tokens[index + 1][0].lower()
            nextpos = tokens[index + 1][1].lower()
            nextnextword = tokens[index + 2][0].lower()
            nextnextpos = tokens[index + 2][1].lower()

        # prevprevword, prevprevpos, prevprevtag, nextnextword and
        # nextnextpos are computed above but not used in the feature set
        # below.

        # 89.6
        features = {
            "bias": True,
            "shape": shape(word),
            "wordlen": len(word),
            "prefix3": word[:3].lower(),
            "suffix3": word[-3:].lower(),
            "pos": pos,
            "word": word,
            "en-wordlist": (word in self._english_wordlist()),
            "prevtag": prevtag,
            "prevpos": prevpos,
            "nextpos": nextpos,
            "prevword": prevword,
            "nextword": nextword,
            "word+nextpos": f"{word.lower()}+{nextpos}",
            "pos+prevtag": f"{pos}+{prevtag}",
            "shape+prevtag": f"{prevshape}+{prevtag}",
        }

        return features
class NEChunkParser(ChunkParserI):
    """
    Expected input: list of pos-tagged words
    """

    def __init__(self, train):
        """
        Train the underlying IOB tagger on ``train``, an iterable of
        chunk-parse trees.
        """
        self._train(train)

    def parse(self, tokens):
        """
        Each token should be a pos-tagged word

        The tokens are IOB-tagged by the trained tagger and the tagged
        sequence is converted to a chunk-parse tree, which is returned.
        """
        return self._tagged_to_parse(self._tagger.tag(tokens))

    def _train(self, corpus):
        """
        Flatten every tree in ``corpus`` to an IOB-tagged token list and
        train a ``NEChunkParserTagger`` on the result.
        """
        # Convert to tagged sequence
        tagged_corpus = list(map(self._parse_to_tagged, corpus))
        self._tagger = NEChunkParserTagger(train=tagged_corpus)

    def _tagged_to_parse(self, tagged_tokens):
        """
        Convert a list of tagged tokens to a chunk-parse tree.

        "O" tokens become direct children of the "S" root.  A "B-X" tag
        always opens a new "X" chunk.  An "I-X" tag extends the preceding
        chunk when that chunk is labelled "X" and otherwise opens a new
        one.  Tokens carrying any other tag are left out of the tree.
        """
        sent = Tree("S", [])
        for token, iob in tagged_tokens:
            if iob == "O":
                sent.append(token)
                continue

            opens_chunk = iob.startswith("B-")
            if not (opens_chunk or iob.startswith("I-")):
                continue  # neither O, B- nor I-: nothing is emitted

            label = iob[2:]
            previous = sent[-1] if sent else None
            extends_previous = (
                not opens_chunk
                and isinstance(previous, Tree)
                and previous.label() == label
            )
            if extends_previous:
                previous.append(token)
            else:
                sent.append(Tree(label, [token]))
        return sent

    @staticmethod
    def _parse_to_tagged(sent):
        """
        Convert a chunk-parse tree to a list of tagged tokens.

        The first leaf of a chunk labelled "X" is tagged "B-X" and the
        remaining leaves "I-X".  Children that are not trees are tagged
        "O".  An empty chunk is reported with a printed warning and
        contributes no tokens.
        """
        tagged = []
        for node in sent:
            if not isinstance(node, Tree):
                tagged.append((node, "O"))
                continue
            if len(node) == 0:
                print("Warning -- empty chunk in sentence")
                continue
            tagged.append((node[0], f"B-{node.label()}"))
            for leaf in node[1:]:
                tagged.append((leaf, f"I-{node.label()}"))
        return tagged
def shape(word):
    """
    Classify the surface form of ``word``.

    Returns one of "number", "punct", "upcase" (``word.istitle()``),
    "downcase" (``word.islower()``), "mixedcase" or "other".
    """
    # NOTE(review): only the second alternative of this pattern is anchored
    # with "$", so any word that merely starts with a digit (e.g. "3rd") is
    # classed as "number".  Kept as is because the result is used as a
    # classifier feature.
    if re.match(r"[0-9]+(\.[0-9]*)?|[0-9]*\.[0-9]+$", word, re.UNICODE):
        return "number"
    if re.match(r"\W+$", word, re.UNICODE):
        return "punct"
    if not re.match(r"\w+$", word, re.UNICODE):
        return "other"
    if word.istitle():
        return "upcase"
    return "downcase" if word.islower() else "mixedcase"
def simplify_pos(s):
    """
    Coarsen a part-of-speech tag.

    Any tag starting with "V" is mapped to "V"; every other tag is cut at
    its first "-" and the part before it is returned.
    """
    if not s.startswith("V"):
        head, _sep, _tail = s.partition("-")
        return head
    return "V"
def postag_tree(tree):
    """
    Return a copy of ``tree`` in which every leaf is a ``(word, pos)`` pair.

    The leaves are tagged in one ``pos_tag`` call and the resulting tags are
    handed out in order while the top-level children are rebuilt under a new
    "S" root.  Chunk subtrees keep their label.
    """
    # Part-of-speech tagging.
    words = tree.leaves()
    tags = (pos for (word, pos) in pos_tag(words))

    tagged = Tree("S", [])
    for node in tree:
        if not isinstance(node, Tree):
            tagged.append((node, next(tags)))
            continue
        chunk = Tree(node.label(), [])
        tagged.append(chunk)
        for leaf in node:
            chunk.append((leaf, next(tags)))
    return tagged
def load_ace_data(roots, fmt="binary", skip_bnews=True):
    """
    Yield the trees loaded from every ".sgm" file found under ``roots``.

    Each root is walked with ``os.walk``; every file whose name ends in
    ".sgm" is passed to ``load_ace_file`` together with ``fmt``.

    :param roots: iterable of directory paths to walk
    :param fmt: format selector forwarded to ``load_ace_file``
    :param skip_bnews: when true, files located directly in a directory
        whose path ends in "bnews" are not loaded
    """
    for top in roots:
        for dirpath, _dirnames, filenames in os.walk(top):
            if dirpath.endswith("bnews") and skip_bnews:
                continue
            for filename in filenames:
                if not filename.endswith(".sgm"):
                    continue
                yield from load_ace_file(os.path.join(dirpath, filename), fmt)
def load_ace_file(textfile, fmt):
    """
    Yield one chunk tree for the ACE document ``textfile``.

    Entity annotations are read from ``textfile + ".tmx.rdc.xml"``; only
    mentions whose ``TYPE`` attribute is "NAME" are used.  The document text
    is tokenized with ``word_tokenize`` and each annotated span becomes a
    subtree of the "S" root whose leaves are the whitespace-split span.

    This is a generator, so nothing happens (including the ``fmt`` check)
    until it is first iterated.

    :param textfile: path to the ".sgm" text file
    :param fmt: "binary" labels every entity chunk "NE"; "multiclass"
        labels it with the entity type from the annotation file
    :raises ValueError: if ``fmt`` is neither "binary" nor "multiclass"
    """
    # Reject an unsupported format before any file is opened or parsed.
    if fmt not in ("binary", "multiclass"):
        raise ValueError("bad fmt value")

    print(f" - {os.path.split(textfile)[1]}")
    annfile = textfile + ".tmx.rdc.xml"

    # Read the xml file, and get a list of entities
    entities = []
    with open(annfile) as infile:
        xml = ET.parse(infile).getroot()
    for entity in xml.findall("document/entity"):
        typ = entity.find("entity_type").text
        for mention in entity.findall("entity_mention"):
            if mention.get("TYPE") != "NAME":
                continue  # only NEs
            s = int(mention.find("head/charseq/start").text)
            # +1 turns the annotated end offset into an exclusive slice end.
            e = int(mention.find("head/charseq/end").text) + 1
            entities.append((s, e, typ))

    # Read the text file, and mark the entities.
    with open(textfile) as infile:
        text = infile.read()

    # Strip XML tags, since they don't count towards the indices
    text = re.sub("<(?!/?TEXT)[^>]+>", "", text)

    # Blank out anything before/after <TEXT>
    def subfunc(m):
        # Replace the match with spaces, 6 fewer than its length
        # (6 == len("<TEXT>")), presumably so that the remaining text keeps
        # the character offsets used by the annotations.
        return " " * (m.end() - m.start() - 6)

    text = re.sub(r"[\s\S]*<TEXT>", subfunc, text)
    text = re.sub(r"</TEXT>[\s\S]*", "", text)

    # Simplify quotes
    text = re.sub("``", ' "', text)
    text = re.sub("''", '" ', text)

    # Walk the entities in offset order, tokenizing the plain text between
    # them and wrapping each entity span in its own subtree.
    i = 0
    toks = Tree("S", [])
    for (s, e, typ) in sorted(entities):
        if s < i:
            s = i  # Overlapping!  Deal with this better?
        if e <= s:
            continue
        toks.extend(word_tokenize(text[i:s]))
        # Binary distinction (NE or not NE) vs. multiclass (NE type)
        label = "NE" if fmt == "binary" else typ
        toks.append(Tree(label, text[s:e].split()))
        i = e
    toks.extend(word_tokenize(text[i:]))
    yield toks
# This probably belongs in a more general-purpose location (as does
# the parse_to_tagged function).
def cmp_chunks(correct, guessed):
    """
    Print a row-by-row comparison of the IOB tags of two chunk trees.

    Both trees are flattened with ``NEChunkParser._parse_to_tagged`` and
    walked in lockstep.  Each printed row shows the correct tag, the guessed
    tag and the token taken from ``guessed``.  A run of tokens tagged "O" in
    both trees is collapsed to its first row followed by one "..." row.

    :param correct: the reference chunk tree
    :param guessed: the chunk tree to compare against it
    """
    correct = NEChunkParser._parse_to_tagged(correct)
    guessed = NEChunkParser._parse_to_tagged(guessed)
    ellipsis = False
    for (_, ct), (w, gt) in zip(correct, guessed):
        if ct == gt == "O":
            if not ellipsis:
                print(f" {ct:15} {gt:15} {w}")
                # All three fields use automatic numbering: mixing "{}" with
                # an explicit index such as "{2}" makes str.format raise
                # ValueError.
                print(" {:15} {:15} {}".format("...", "...", "..."))
                ellipsis = True
        else:
            ellipsis = False
            print(f" {ct:15} {gt:15} {w}")
def build_model(fmt="binary", outfilename=None):
    """
    Train a named-entity chunker on the ACE corpora, evaluate it, pickle it
    and return it.

    Training trees are loaded from the ``corpora/ace_data`` directories
    "ace.dev", "ace.heldout", "bbn.dev" and "muc.dev"; evaluation trees from
    "ace.eval".  The chunk score is printed, along with a tag comparison for
    the first three evaluation sentences.

    :param fmt: format selector forwarded to ``load_ace_data``
    :param outfilename: path the trained chunker is pickled to.  Defaults to
        ``/tmp/ne_chunker_<fmt>.pickle``.
    :return: the trained ``NEChunkParser``
    """
    print("Loading training data...")
    train_paths = [
        find("corpora/ace_data/ace.dev"),
        find("corpora/ace_data/ace.heldout"),
        find("corpora/ace_data/bbn.dev"),
        find("corpora/ace_data/muc.dev"),
    ]
    train_trees = load_ace_data(train_paths, fmt)
    train_data = [postag_tree(t) for t in train_trees]
    print("Training...")
    cp = NEChunkParser(train_data)
    # Drop the training set before the evaluation set is loaded.
    del train_data

    print("Loading eval data...")
    eval_paths = [find("corpora/ace_data/ace.eval")]
    eval_trees = load_ace_data(eval_paths, fmt)
    eval_data = [postag_tree(t) for t in eval_trees]

    print("Evaluating...")
    chunkscore = ChunkScore()
    for i, correct in enumerate(eval_data):
        guess = cp.parse(correct.leaves())
        chunkscore.score(correct, guess)
        # Show a tag comparison for the first three sentences only.
        if i < 3:
            cmp_chunks(correct, guess)
    print(chunkscore)

    if outfilename is None:
        outfilename = f"/tmp/ne_chunker_{fmt}.pickle"
    print(f"Saving chunker to {outfilename}...")

    with open(outfilename, "wb") as outfile:
        # Protocol -1 selects the highest pickle protocol available.
        pickle.dump(cp, outfile, -1)

    return cp
if __name__ == "__main__":
    # Make sure that the pickled object has the right class name:
    # build_model is imported through its package path so that the chunker
    # classes are pickled under nltk.chunk.named_entity rather than under
    # __main__.
    from nltk.chunk.named_entity import build_model

    # Train, evaluate and save one model per supported format.
    build_model("binary")
    build_model("multiclass")