• For any query, contact us at
  • +91-9872993883
  • +91-8283824812
  • info@ris-ai.com

Text Classification

% Start from a clean session: remove workspace variables and clear the Command Window.
clear
clc
% Word embedding used later to convert words to vectors (emb.Dimension sets the feature size).
emb = fastTextWordEmbedding;
% Training file and the names of the text and label columns to read from it.
filenameTrain = "weatherReportsTrain.csv";
textName = "event_narrative";
labelName = "event_type";
% textName and labelName are string scalars, so [textName labelName] is a
% two-element string array selecting exactly those two columns.
ttdsTrain = tabularTextDatastore(filenameTrain,'SelectedVariableNames',[textName labelName]);
% Read 8 rows per call to read().
ttdsTrain.ReadSize = 8;
% No trailing semicolon: the preview table is printed.
preview(ttdsTrain)
ans = 8×2 table
  event_narrative event_type
1 'Large tree down between Plantersville and Nettleton.' 'Thunderstorm Wind'
2 'One to two feet of deep standing water developed on a street on the Winthrop University campus after more than an inch of rain fell in less than an hour. One vehicle was stalled in the water.' 'Heavy Rain'
3 'NWS Columbia relayed a report of trees blown down along Tom Hall St.' 'Thunderstorm Wind'
4 'Media reported two trees blown down along I-40 in the Old Fort area.' 'Thunderstorm Wind'
5 'A few tree limbs greater than 6 inches down on HWY 18 in Roseland.' 'Thunderstorm Wind'
6 'Awning blown off a building on Lamar Avenue. Multiple trees down near the intersection of Winchester and Perkins.' 'Thunderstorm Wind'
7 'Tin roof ripped off house on Old Memphis Road near Billings Drive. Several large trees down in the area.' 'Thunderstorm Wind'
8 'Powerlines down at Walnut Grove and Cherry Lane roads.' 'Thunderstorm Wind'
% Gather all training labels. readLabels is a helper that is not defined in
% the visible part of this file - TODO confirm it is available on the path.
labels = readLabels(ttdsTrain,labelName);
numObservations = numel(labels);
classNames = unique(labels);

% Wrap the raw datastore so each read is passed through transformTextData
% (also defined outside the visible part of this file) together with the
% target sequence length, the embedding and the class list.
% The statement is intentionally left without a semicolon so the resulting
% datastore is displayed.
sequenceLength = 100;
tdsTrain = transform(ttdsTrain, ...
    @(data) transformTextData(data,sequenceLength,emb,classNames))
tdsTrain =
TransformedDatastore with properties: UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore] Transforms: {@(data)transformTextData(data,sequenceLength,emb,classNames)} IncludeInfo: 0
% Print the first rows of the transformed datastore (predictors and responses).
preview(tdsTrain)
ans = 8×2 table
s.no predictors responses
1 1×100×300 single Thunderstorm Wind
2 1×100×300 single Heavy Rain
3 1×100×300 single Thunderstorm Wind
4 1×100×300 single Thunderstorm Wind
5 1×100×300 single Thunderstorm Wind
6 1×100×300 single Thunderstorm Wind
7 1×100×300 single Thunderstorm Wind
8 1×100×300 single Thunderstorm Wind
% Validation data: same selected columns and same transform as the training data.
filenameValidation = "weatherReportsValidation.csv";
ttdsValidation = tabularTextDatastore(filenameValidation, ...
    'SelectedVariableNames',[textName labelName]);
tdsValidation = transform(ttdsValidation, ...
    @(data) transformTextData(data,sequenceLength,emb,classNames));

% Network dimensions.
numFeatures = emb.Dimension;
inputSize = [1 sequenceLength numFeatures];
numFilters = 200;

% One convolutional block is built per n-gram length.
ngramLengths = 2:6;
numBlocks = numel(ngramLengths);
numClasses = numel(classNames);

% Begin the layer graph with an un-normalized input layer named 'input'.
layer = imageInputLayer(inputSize,'Normalization','none','Name','input');
lgraph = layerGraph(layer);
% For each of the n-gram lengths, create a block of convolution, batch normalization,
% ReLU, dropout, and max pooling layers. Connect each block to the input layer.
% Layer names carry the n-gram length as a suffix ("conv2", "bn2", ..., "max6")
% so that every block has unique names inside the graph.
for j = 1:numBlocks
N = ngramLengths(j);
block = [
convolution2dLayer([1 N],numFilters,'Name',"conv"+N,'Padding','same')
batchNormalizationLayer('Name',"bn"+N)
reluLayer('Name',"relu"+N)
dropoutLayer(0.2,'Name',"drop"+N)
maxPooling2dLayer([1 sequenceLength],'Name',"max"+N)];
% Add the block as a separate branch, then feed it from the shared input layer.
lgraph = addLayers(lgraph,block);
lgraph = connectLayers(lgraph,'input',"conv"+N);
end
View the network architecture in a plot.
% Plot the layer graph built so far (input layer plus the parallel blocks) in a new figure.
figure
plot(lgraph)
title("Network Architecture")
Network Architecture
% Output stage: a depth concatenation layer with one input per block, followed by
% a fully connected layer sized to the number of classes, softmax and classification.
layers = [
depthConcatenationLayer(numBlocks,'Name','depth')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','soft')
classificationLayer('Name','classification')];
% addLayers only adds these layers to the graph; the blocks are not yet
% connected to 'depth' at this point.
lgraph = addLayers(lgraph,layers);
% Plot the graph again to show the newly added, still unconnected output stage.
figure
plot(lgraph)
title("Network Architecture")
Network Architecture
% Wire the max pooling output of every block into its own input port
% ("depth/in1", "depth/in2", ...) of the depth concatenation layer.
for j = 1:numBlocks
    lgraph = connectLayers(lgraph,"max"+ngramLengths(j),"depth/in"+j);
end

% Plot the fully connected network.
figure
plot(lgraph)
title("Network Architecture")
Network Architecture

Training Network

% Training configuration.
miniBatchSize = 128;

% Validate roughly once per epoch. Clamp to at least 1: with fewer observations
% than one mini-batch, floor() yields 0, which is not a valid validation frequency.
numIterationsPerEpoch = max(1,floor(numObservations/miniBatchSize));

% 'Shuffle','never' is stated explicitly because the transformed datastore is
% not shuffleable. With the default setting trainNetwork warns about this and
% then trains without shuffling anyway, so this option only makes the actual
% behavior explicit and avoids the warning.
options = trainingOptions('adam', ...
    'MaxEpochs',10, ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','never', ...
    'ValidationData',tdsValidation, ...
    'ValidationFrequency',numIterationsPerEpoch, ...
    'Plots','training-progress', ...
    'Verbose',false);

% Train the network on the transformed training datastore.
net = trainNetwork(tdsTrain,lgraph,options);
Warning: Input datastore is not Shuffleable but trainingOptions specified shuffling. Training will proceed without shuffling.
Warning: Input datastore is not Shuffleable but trainingOptions specified shuffling. Training will proceed without shuffling.
Training Progress

Testing Network

% Test data: same selected columns and same transform as the training data.
filenameTest = "weatherReportsTest.csv";
ttdsTest = tabularTextDatastore(filenameTest,'SelectedVariableNames',[textName labelName]);
% No trailing semicolon: the transformed datastore is displayed.
tdsTest = transform(ttdsTest, @(data) transformTextData(data,sequenceLength,emb,classNames))
tdsTest =
TransformedDatastore with properties: UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore] Transforms: {@(data)transformTextData(data,sequenceLength,emb,classNames)} IncludeInfo: 0
% Read the true test labels (readLabels is defined outside the visible part of this file).
labelsTest = readLabels(ttdsTest,labelName);
% Build the categorical ground truth using the class list taken from the training labels.
YTest = categorical(labelsTest,classNames);
% Predict a class for every test observation.
YPred = classify(net,tdsTest);
% Fraction of predictions that match the true label; printed because there is no semicolon.
% NOTE(review): assumes YPred and YTest have the same shape and order - confirm.
accuracy = mean(YPred == YTest)
accuracy = 0.885463599715438
Copyright belongs to RIS.

Resources You Will Ever Need