#! /usr/bin/env python

"""
Main function is wsd_demo, which has an obligatory event file argument.
The file must be a sequence of events in the following line-by-line format::

    BEGIN EVENT
    CLASS
    FEATURE1 VALUE1
    FEATURE2 VALUE2
    ....
    FEATUREN VALUEN
    END EVENT

The convention being followed is that files in this special format have the
extension '.evt'.
"""

from __future__ import print_function

from nltk.classify import maxent


def read_event_file(eventfile):
    """Parse an event file into a list of labelled feature sets.

    Each event starts with a ``BEGIN EVENT`` line, followed by one line
    holding the class (sense), then any number of ``FEATURE VALUE`` lines,
    and ends with an ``END EVENT`` line.  A feature value of ``'0'`` is
    stored as ``False``; every other value is stored as ``True``.

    Blank lines in feature position (or between events) are ignored.  On a
    malformed line a ``Format error`` message with the line number is
    printed and parsing stops; the events completed so far are returned.

    :param eventfile: path of the event file to read.
    :return: list of ``(feature_dict, sense)`` tuples, in file order.
    """
    event_list = []
    feature_dict = None    # None while we are outside a BEGIN/END pair
    sense = None
    expect_sense = False   # True only for the line right after BEGIN EVENT

    # The context manager guarantees the handle is closed, even on error.
    with open(eventfile, 'r') as fh:
        for ctr, line in enumerate(fh, 1):
            line = line.strip()
            if line == 'BEGIN EVENT':
                expect_sense = True
                feature_dict = {}
                sense = None
            elif line == 'END EVENT':
                if not sense:
                    # END without a class line (or without a BEGIN at all).
                    print('Format error: line number %d' % (ctr,))
                    break
                event_list.append((feature_dict, sense))
                expect_sense = False
                sense = None
                feature_dict = None
            elif expect_sense:
                sense = line
                expect_sense = False
            elif not line:
                # Tolerate blank separator lines and trailing newlines.
                continue
            else:
                fields = line.split()
                if feature_dict is None or len(fields) != 2:
                    # Feature line outside an event, or not "NAME VALUE".
                    print('Format error: line number %d' % (ctr,))
                    break
                key, val = fields
                feature_dict[key] = (val != '0')
    return event_list


def wsd_demo(event_file, trainer, n=1000, **cutoffs):
    """Train and evaluate a classifier on the events stored in a file.

    The events are shuffled with a fixed seed; of the first ``n`` shuffled
    events, the first 80% are used for training and the rest for testing.
    The test accuracy is printed and, when the classifier supports
    probability estimates, so is the average log likelihood of the gold
    labels.

    :param event_file: path of the event file (see ``read_event_file``).
    :param trainer: callable invoked as ``trainer(train, **cutoffs)``; it
        must return the trained classifier.
    :param n: maximum number of events to use; capped at the number of
        events actually read.
    :param cutoffs: extra keyword arguments forwarded to ``trainer``.
    :return: the trained classifier.
    """
    import random
    from nltk.classify import accuracy

    print('Reading data...')
    events = read_event_file(event_file)
    if n > len(events):
        n = len(events)
    # Sorted so that the printed list is stable between runs.
    senses = sorted(set(label for (_, label) in events))
    print(' Senses: ' + ' '.join(senses))

    # Randomly split the events into a test & train set.
    print('Splitting into test & train...')
    random.seed(123456)
    random.shuffle(events)
    split = int(.8 * n)
    train = events[:split]
    test = events[split:n]

    # Train up a classifier -- on the training portion only, so that the
    # test events stay unseen.
    print('Training classifier...')
    classifier = trainer(train, **cutoffs)

    # Run the classifier on the test data.
    print('Testing classifier...')
    acc = accuracy(classifier, test)
    print('Accuracy: %6.4f' % acc)

    # For classifiers that can find probabilities, show the log
    # likelihood.
    # NLTK 3 renamed batch_prob_classify to prob_classify_many; use
    # whichever one this classifier provides.
    prob_classify_many = getattr(classifier, 'prob_classify_many', None)
    if prob_classify_many is None:
        prob_classify_many = classifier.batch_prob_classify
    try:
        # The classifier expects bare feature sets, not (features, label).
        test_featuresets = [featureset for (featureset, _) in test]
        pdists = prob_classify_many(test_featuresets)
        ll = [pdist.logprob(gold)
              for ((_, gold), pdist) in zip(test, pdists)]
        print('Avg. log likelihood: %6.4f' % (sum(ll) / len(test)))
    except NotImplementedError:
        pass

    # Return the classifier
    return classifier


if __name__ == '__main__':
    import sys

    if len(sys.argv) > 1:
        event_file = sys.argv[1]
        if len(sys.argv) > 2:
            max_iter = int(sys.argv[2])
        else:
            max_iter = 50
    else:
        print('Usage: %s EVENT_FILE [MAX_ITER]' % (sys.argv[0],))
        sys.exit()
    wsd_demo(event_file, maxent.MaxentClassifier.train, 4100,
             max_iter=max_iter)