Statistics For Programmers - Naive Bayes Classifier
Putting together everything we've learned thus far, let's build a Naive Bayes Classifier. It is a simple probabilistic classifier that applies Bayes' theorem. It's a popular choice for text classification problems because of its simplicity and efficiency. However, it assumes that the features being classified (here, the words in an email) are independent of each other given the class. That assumption rarely holds in practice, and it is what makes this a "naive" form of classification.
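For reference, here is the standard textbook form of the quantity the classifier is after: the probability that an email made up of the words w_1 through w_n is spam, using Bayes' theorem together with the naive independence assumption.

```latex
P(\text{spam} \mid w_1, \ldots, w_n)
  = \frac{P(\text{spam}) \prod_{i=1}^{n} P(w_i \mid \text{spam})}
         {P(\text{spam}) \prod_{i=1}^{n} P(w_i \mid \text{spam})
          + P(\text{not spam}) \prod_{i=1}^{n} P(w_i \mid \text{not spam})}
```

The product over the words is the naive part: each word's likelihood is treated as independent of the others once the class is known. With a single word (n = 1), this reduces to the P(spam | word) calculation used in the Making Predictions section.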
A common use case for Naive Bayes classification is spam detection, which is what we'll be building. The classifier will be trained on a small dataset of spam and non-spam emails and then used to make predictions on new emails.
Preparing the Data
For this project, we'll be using a handcrafted and annotated dataset of spam and non-spam emails sourced from my own inbox (thanks, spammers!). I picked out a few examples by hand, but I encourage you to try this on larger datasets if you can get your hands on them.
const trainData = [
{ email: "free viagra now!!!", spam: true },
{ email: "free lottery ticket", spam: true },
{ email: "get your free lottery ticket right now", spam: true },
{ email: "skinny pills, melt your fat away", spam: true },
{ email: "get your free pills now", spam: true },
{ email: "free pills to lose weight", spam: true },
{ email: "you haven't claimed your free bitcoin", spam: true },
{ email: "your free bitcoin is waiting for you", spam: true },
{
email: "[Action Required] Your free crypto is about to expire",
spam: true,
},
{ email: "Last chance to claim 0.1 eth", spam: true },
{ email: "Architecture Meeting at 3pm", spam: false },
{
email: "[Action Requrired] Timecards need to be submitted by 9 am",
spam: false,
},
{ email: "Please review the attached document", spam: false },
{ email: "Your Amazon order has been shipped", spam: false },
{ email: "Invitation: 1:1 with John Doe", spam: false },
{ email: "Phoenix kickoff meeting moved to 04/15", spam: false },
{ email: "Program Executive report", spam: false },
{ email: "Jane Doe has shared a calendar with you", spam: false },
{ email: "Project X Norming meeting notes", spam: false },
{ email: "Your order has been confirmed", spam: false },
];
The training data is formatted as a list of objects where each object has an email field containing the text content of the email and a spam field representing the class that the email falls into (spam or not spam).
We want our sample here to be representative of the overall population of emails, because the classifier bases its predictions on the patterns it learns from the training data. A bit of research suggested that the average spam rate in emails is around 50%[1], so our training data uses a matching split of 10 spam and 10 non-spam emails. This is particularly important if you are generating synthetic data or pulling data from an external source: if the training data is not representative, the classifier may not perform well on new data.
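Checking the class balance of a dataset only takes a few lines. A minimal sketch: spamRate is a hypothetical helper name, and the sample reuses four emails from the training data above in the same { email, spam } shape.

```javascript
// Hypothetical helper: share of items labelled as spam in a dataset.
// Items are assumed to have the same { email, spam } shape as trainData.
function spamRate(data) {
  if (data.length === 0) return NaN; // no data, no meaningful rate
  const spamCount = data.filter((item) => item.spam).length;
  return spamCount / data.length;
}

// Four emails taken from the training set above.
const sample = [
  { email: "free lottery ticket", spam: true },
  { email: "Architecture Meeting at 3pm", spam: false },
  { email: "get your free pills now", spam: true },
  { email: "Program Executive report", spam: false },
];

console.log(spamRate(sample)); // 0.5
```

Applied to the full trainData, this also returns 0.5, since 10 of the 20 training emails are labelled spam.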
Building our Vocabulary
The first step is to tokenize our dataset. Tokenization is the process of breaking down a text into individual words or phrases. We'll be using a simple tokenization function that splits the email text into individual words.
function tokenize(email) {
return email.toLowerCase().split(/\s*\b\s*/);
}
We're converting the emails to lowercase as a normalization step because we want the classifier to treat "Free" and "free" as the same word. Additionally, we're splitting on word boundaries (and trimming any whitespace around them), so each word becomes a token and runs of punctuation such as "!!!" or "," become tokens of their own. It's imperfect, since words containing punctuation, such as "don't", get split into "don", "'", and "t", but it's a good starting point.
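To see the "don't" limitation concretely, here is the article's tokenize next to a hypothetical alternative, tokenizeKeepContractions, that keeps apostrophes inside words. The regular expression is an assumption of this sketch: it only treats ASCII letters and digits as word characters, and it would keep stray quotes attached to a word (for example, "'hi'" stays a single token).

```javascript
// The article's tokenizer: lowercase, then split on word boundaries,
// trimming any whitespace around each boundary.
function tokenize(email) {
  return email.toLowerCase().split(/\s*\b\s*/);
}

// Hypothetical alternative (not from the article): runs of letters, digits,
// and apostrophes form words; runs of other non-space characters form
// punctuation tokens. String.prototype.match returns null when nothing matches.
function tokenizeKeepContractions(email) {
  return email.toLowerCase().match(/[a-z0-9']+|[^\sa-z0-9']+/g) ?? [];
}

console.log(tokenize("Don't miss this!!!"));
// -> ["don", "'", "t", "miss", "this", "!!!"]
console.log(tokenizeKeepContractions("Don't miss this!!!"));
// -> ["don't", "miss", "this", "!!!"]
```

Swapping a different tokenizer in would change the vocabulary and every count and probability that follows, so the rest of this post sticks with the original tokenize.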
Applying this to our training data, we can build a vocabulary of all the unique words in our dataset.
let vocabulary = new Set();
trainData.forEach(({ email }) => {
const words = tokenize(email);
words.forEach((word) => vocabulary.add(word));
});
vocabulary = Array.from(vocabulary);
console.log(vocabulary);
// [
// 'free', 'viagra', 'now', '!!!', 'lottery',
// 'ticket', 'get', 'your', 'right', 'skinny',
// 'pills', ',', 'melt', 'fat', 'away',
// ...
// ]
Now that we have our vocabulary, we can move on to building our word frequencies. This will be a dictionary where the keys are words and the values are the number of times that word appears in spam and non-spam emails.
const trainSpam = trainData
.filter((email) => email.spam)
.map((email) => email.email)
.join(" ")
.toLowerCase();
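// Count occurrences of each vocabulary word in the joined spam text.
// split(word) matches substrings, not whole tokens, so a short token like
// "t" (from "haven't") is also counted inside words such as "ticket".
const wordFreqSpam = vocabulary.reduce((acc, word) => {
acc[word] = trainSpam.split(word).length - 1;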
return acc;
}, {});
console.log(wordFreqSpam);
// {
// '0': 1,
// '1': 1,
// '9': 0,
// '15': 0,
// free: 8,
// viagra: 1,
// now: 3,
// '!!!': 1,
// lottery: 2,
// ticket: 2,
// get: 2,
// your: 6,
// right: 1,
// ...
// }
We repeat the same process for non-spam emails.
const trainNotSpam = trainData
.filter((email) => !email.spam)
.map((email) => email.email)
.join(" ")
.toLowerCase();
const wordFreqNotSpam = vocabulary.reduce((acc, word) => {
acc[word] = trainNotSpam.split(word).length - 1;
return acc;
}, {});
console.log(wordFreqNotSpam);
// {
// '0': 1,
// '1': 3,
// '9': 1,
// '15': 1,
// free: 0,
// viagra: 0,
// now: 0,
// '!!!': 0,
// lottery: 0,
// ticket: 0,
// get: 0,
// your: 2,
// right: 0,
// ...
// }
We can begin to see patterns emerge in the word frequencies. For example, the word "free" appears 8 times in spam emails but 0 times in non-spam emails. This is a feature that our classifier will pick up on and use to make predictions.
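Because these counts come from split(word), they count substring matches rather than whole tokens. The non-spam count of 3 for '1', for example, includes the "1" inside "04/15". A minimal sketch of token-based counting using the same tokenize function; countTokens is a hypothetical helper, not part of the original code.

```javascript
// Same tokenizer as the article.
function tokenize(email) {
  return email.toLowerCase().split(/\s*\b\s*/);
}

// Hypothetical helper: count whole-token occurrences across a list of emails.
// Returns a Map from token to count.
function countTokens(emails) {
  const counts = new Map();
  for (const email of emails) {
    for (const token of tokenize(email)) {
      counts.set(token, (counts.get(token) ?? 0) + 1);
    }
  }
  return counts;
}

const emails = [
  "Invitation: 1:1 with John Doe",
  "Phoenix kickoff meeting moved to 04/15",
];

// Substring counting, as in wordFreqNotSpam: the "1" inside "15" is counted too.
const joined = emails.join(" ").toLowerCase();
console.log(joined.split("1").length - 1); // 3

// Token counting: only standalone "1" tokens are counted.
console.log(countTokens(emails).get("1")); // 2
```

Using token counts would change the frequencies and every probability derived from them, so the numbers in the rest of this post come from the original substring-based version.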
Making Predictions
With our vocabulary and word frequencies in place, we have the necessary components to make predictions. To measure the "spamminess" (or "spamicity") of a word, we first estimate how likely the word is to appear in spam emails, P(word | spam), and then in non-spam emails, P(word | not spam). Adding 1 to each count (Laplace smoothing) ensures that a word never seen in one class still gets a small non-zero probability.
const totalSpam = trainSpam.length;
const wordProbSpam = Object.keys(wordFreqSpam).reduce((acc, word) => {
acc[word] = (wordFreqSpam[word] + 1) / (totalSpam + 2);
return acc;
}, {});
console.log(wordProbSpam);
// {
// '0': 0.00625,
// '1': 0.00625,
// '9': 0.003125,
// '15': 0.003125,
// free: 0.028125,
// viagra: 0.00625,
// now: 0.0125,
// '!!!': 0.00625,
// lottery: 0.009375,
// ticket: 0.009375,
// get: 0.009375,
// your: 0.021875,
// right: 0.00625,
// skinny: 0.00625,
// pills: 0.0125,
// ',': 0.00625,
// melt: 0.00625,
// fat: 0.00625,
// ...
// }
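We can see that the word "free" gets a smoothed spam probability of 0.028125: its 8 spam occurrences plus 1, divided by totalSpam + 2 = 320. Note that totalSpam is trainSpam.length, which counts characters (318) rather than words, so these values are best read as relative scores rather than true per-word probabilities. On its own, the value doesn't say how much more likely "free" is in spam than in non-spam; for that we need the matching non-spam estimate.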
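The +1 in the numerator is Laplace (add-one) smoothing. Without it, a word that never appears in one class, such as "free" in non-spam, would get a probability of exactly 0 there, and a single such word would rule that class out entirely. The code above divides by totalSpam + 2. A common choice for word counts in multinomial Naive Bayes is to divide by the total number of word tokens plus the vocabulary size, which makes the smoothed probabilities over the vocabulary sum to 1. A sketch of that variant; smoothedLikelihoods and the toy counts are hypothetical.

```javascript
// Hypothetical helper: add-one (Laplace) smoothing for multinomial Naive Bayes.
// `counts` maps token -> occurrences in one class; `vocabulary` lists every
// distinct token seen in training across all classes.
function smoothedLikelihoods(counts, vocabulary) {
  let totalTokens = 0;
  for (const n of counts.values()) totalTokens += n;

  const likelihoods = new Map();
  for (const word of vocabulary) {
    // +1 per word in the numerator, +|V| in the denominator,
    // so the probabilities over the vocabulary sum to 1.
    const count = counts.get(word) ?? 0;
    likelihoods.set(word, (count + 1) / (totalTokens + vocabulary.length));
  }
  return likelihoods;
}

// Toy counts for illustration: 3 spam tokens, 3 vocabulary words.
const vocab = ["free", "meeting", "now"];
const spamCounts = new Map([["free", 2], ["now", 1]]);
const probs = smoothedLikelihoods(spamCounts, vocab);

console.log(probs.get("free")); // 0.5
console.log(probs.get("meeting")); // 0.16666666666666666
```

"meeting" never occurs in the toy spam counts but still receives 1/6 rather than 0.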
We repeat the same process for non-spam emails.
const totalNotSpam = trainNotSpam.length;
const wordProbNotSpam = Object.keys(wordFreqNotSpam).reduce((acc, word) => {
acc[word] = (wordFreqNotSpam[word] + 1) / (totalNotSpam + 2);
return acc;
}, {});
console.log(wordProbNotSpam);
// {
// '0': 0.005649717514124294,
// '1': 0.011299435028248588,
// '9': 0.005649717514124294,
// '15': 0.005649717514124294,
// free: 0.002824858757062147,
// viagra: 0.002824858757062147,
// now: 0.002824858757062147,
// '!!!': 0.002824858757062147,
// lottery: 0.002824858757062147,
// ticket: 0.002824858757062147,
// get: 0.002824858757062147,
// your: 0.00847457627118644,
// right: 0.002824858757062147,
// skinny: 0.002824858757062147,
// pills: 0.002824858757062147,
// ',': 0.002824858757062147,
// melt: 0.002824858757062147,
// fat: 0.002824858757062147,
// ...
// }
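Next, we combine the two sets of estimates using Bayes' theorem. We need the prior probability that an email is spam, P(spam), the overall probability of seeing a word, P(word), and finally the probability that an email is spam given that it contains the word, P(spam | word). Because totalSpam and totalNotSpam are the character lengths of the joined texts (318 and 352), probSpam here works out to 318 / 670 ≈ 0.475 rather than the 50% share of spam emails in the dataset.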
const totalEmails = totalSpam + totalNotSpam;
const probSpam = totalSpam / totalEmails;
const probNotSpam = totalNotSpam / totalEmails;
const probWord = (word) => {
// This is the probability of a word appearing in an email
return wordProbSpam[word] * probSpam + wordProbNotSpam[word] * probNotSpam;
};
const probSpamGivenWord = (word) => {
// This is the probability that an email is spam given that it contains the word (Bayes' theorem)
return (wordProbSpam[word] * probSpam) / probWord(word);
};
console.log(probWord("free")); // 0.014832985496247576
console.log(probSpamGivenWord("free")); // 0.8999456380775066
The probability of the word "free" appearing in an email is about 0.0148, and the probability that an email is spam given that it contains "free" is about 0.8999. In other words, based on our training data, an email containing the word "free" has roughly a 90% chance of being spam.
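To make the arithmetic concrete, here is the same calculation for "free" done by hand, plugging in the values behind the numbers above: 8 spam occurrences, 0 non-spam occurrences, and the character totals 318 and 352.

```javascript
// Smoothed estimates, as computed by wordProbSpam / wordProbNotSpam:
// (count + 1) / (total characters + 2).
const pFreeGivenSpam = (8 + 1) / (318 + 2); // 0.028125
const pFreeGivenNotSpam = (0 + 1) / (352 + 2); // ~0.00282

// Prior as computed by probSpam: the spam share of all training characters.
const pSpam = 318 / (318 + 352);
const pNotSpam = 1 - pSpam;

// Law of total probability: P(free) = P(free|spam)P(spam) + P(free|not spam)P(not spam)
const pFree = pFreeGivenSpam * pSpam + pFreeGivenNotSpam * pNotSpam;

// Bayes' theorem: P(spam|free) = P(free|spam)P(spam) / P(free)
const pSpamGivenFree = (pFreeGivenSpam * pSpam) / pFree;

console.log(pFree.toFixed(4)); // 0.0148
console.log(pSpamGivenFree.toFixed(4)); // 0.8999
```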
Putting it all together
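We can now use this to classify an email as spam or not spam. For each word in the email that appears in our vocabulary, we compute P(spam | word) and use the highest of these values as the email's spam score. This is a simplification: a full Naive Bayes classifier would combine the evidence from all of the words rather than relying on the single most spammy one. For this demonstration, we'll use a simple threshold of 0.5. If the score is greater than or equal to 0.5, we'll classify the email as spam; otherwise we'll classify it as not spam. In practice you might want to experiment with different thresholds to see what works best for your data.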
const classifyEmail = (email, threshold = 0.5) => {
// tokenize the same way as during training and ignore words not in the vocabulary
const words = tokenize(email).filter((word) => vocabulary.includes(word));
const probSpamGivenWords = words.map((word) => probSpamGivenWord(word));
// score the email by its most spammy word (the highest P(spam | word));
// fall back to the prior P(spam) when no words are known, since
// Math.max() of an empty list is -Infinity
const maxProb =
probSpamGivenWords.length > 0 ? Math.max(...probSpamGivenWords) : probSpam;
return {
spam: maxProb >= threshold,
prob: maxProb,
};
};
console.log(classifyEmail("get free crypto wallet!!!"));
// { spam: true, prob: 0.8999456380775066 }
The email "get free crypto wallet!!!" is classified as spam with a probability of 0.8999. This is consistent with our expectations based on the training data.
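classifyEmail scores an email by its single most spammy word. A classifier that follows the Naive Bayes model more closely would combine every word's evidence by multiplying P(word | class) across the words. A minimal sketch that does this in log space, so that long emails don't underflow to 0; naiveBayesPosterior is a hypothetical helper, and the toy probabilities in the example are made up for illustration.

```javascript
// Hypothetical helper: P(spam | words) under the naive independence assumption.
// likelihoodSpam / likelihoodNotSpam are plain objects mapping word -> P(word | class).
// Words missing from either object are skipped.
function naiveBayesPosterior(words, priorSpam, likelihoodSpam, likelihoodNotSpam) {
  // Own-property check, so keys like "constructor" aren't picked up from the prototype.
  const has = (obj, key) => Object.prototype.hasOwnProperty.call(obj, key);

  // Sum log-probabilities instead of multiplying raw probabilities.
  let logSpam = Math.log(priorSpam);
  let logNotSpam = Math.log(1 - priorSpam);
  for (const word of words) {
    if (!has(likelihoodSpam, word) || !has(likelihoodNotSpam, word)) continue;
    logSpam += Math.log(likelihoodSpam[word]);
    logNotSpam += Math.log(likelihoodNotSpam[word]);
  }

  // Normalise the two class scores: P(spam | words) = 1 / (1 + e^(logNotSpam - logSpam)).
  return 1 / (1 + Math.exp(logNotSpam - logSpam));
}

// Toy likelihoods for illustration only (not from the training data).
const likelihoodSpam = { free: 0.03, meeting: 0.001 };
const likelihoodNotSpam = { free: 0.003, meeting: 0.02 };

console.log(naiveBayesPosterior(["free"], 0.5, likelihoodSpam, likelihoodNotSpam));
// -> about 0.909 (10 / 11)
console.log(naiveBayesPosterior(["free", "meeting"], 0.5, likelihoodSpam, likelihoodNotSpam));
// -> about 0.333 (1 / 3)
```

With the article's values, a call would look like naiveBayesPosterior(tokenize(email), probSpam, wordProbSpam, wordProbNotSpam). Unlike taking the maximum, a word that is much more common in non-spam (such as "meeting" in the toy example) can pull the score back down.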
Testing the Classifier
To test the classifier, I generated a synthetic dataset of 10 emails using ChatGPT.
const testData = [
{ email: "Get your free trial today!", spam: true },
{ email: "You've won a million dollars!", spam: true },
{ email: "Limited time offer: Buy one get one free!", spam: true },
{
email: "Congratulations! You've been selected for a special discount.",
spam: true,
},
{ email: "Exclusive deal: 50% off all purchases!", spam: true },
{ email: "Meeting agenda for next week", spam: false },
{ email: "Reminder: Payment due tomorrow", spam: false },
{ email: "Invitation to collaborate on a project", spam: false },
{ email: "Monthly newsletter: April edition", spam: false },
{ email: "Thank you for your recent purchase", spam: false },
];
We can now use our classifier to make predictions on the test data and compute the mean accuracy.
const testResults = testData.map((email) => {
const result = classifyEmail(email.email);
return {
email: email.email,
spam: email.spam,
predicted: result.spam,
prob: result.prob,
};
});
const accuracy =
testResults.reduce((acc, email) => {
if (email.spam === email.predicted) {
acc++;
}
return acc;
}, 0) / testData.length;
console.log(accuracy); // 0.6
We managed to achieve an accuracy of 60% on our test data, with 6 of the 10 emails classified correctly. This is a reasonable starting point considering the limited training data, but there's clearly a lot of room for improvement.
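Accuracy alone doesn't show which way the classifier errs, and for spam filtering the two kinds of mistake are often weighed differently, since flagging a legitimate email can be costlier than letting a spam email through. A minimal sketch that breaks results into true and false positives and negatives, plus precision and recall; evaluate is a hypothetical helper, and the sample results are made up.

```javascript
// Hypothetical helper: confusion counts, accuracy, precision, and recall.
// Each result needs boolean `spam` (actual) and `predicted` fields,
// matching the shape of testResults.
function evaluate(results) {
  let tp = 0, fp = 0, tn = 0, fn = 0;
  for (const { spam, predicted } of results) {
    if (spam && predicted) tp++; // spam correctly flagged
    else if (!spam && predicted) fp++; // legitimate email flagged as spam
    else if (!spam && !predicted) tn++; // legitimate email let through
    else fn++; // spam that slipped through
  }
  // Undefined ratios (e.g. nothing flagged) are reported as NaN.
  const ratio = (num, den) => (den === 0 ? NaN : num / den);
  return {
    tp, fp, tn, fn,
    accuracy: ratio(tp + tn, results.length),
    precision: ratio(tp, tp + fp), // of emails flagged as spam, share that were spam
    recall: ratio(tp, tp + fn), // of actual spam, share that was flagged
  };
}

// Made-up results for illustration.
const results = [
  { spam: true, predicted: true },
  { spam: true, predicted: false },
  { spam: false, predicted: true },
  { spam: false, predicted: false },
  { spam: false, predicted: false },
];

console.log(evaluate(results));
// -> tp 1, fp 1, tn 2, fn 1, accuracy 0.6, precision 0.5, recall 0.5
```

Calling evaluate(testResults) on the article's test results would show how the 40% of errors splits between missed spam and misflagged legitimate emails.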
[1] Lever, R. (2022) What spam email is and how to stop it. www.usnews.com. Available at: https://www.usnews.com/360-reviews/privacy/what-spam-email-is (Accessed: 2024-04-20).