我已经检查了 dl4j 示例并执行了AnimalsClassification示例以成功测试。
我必须训练、评估和预测(使用)像UNet这样的语义分割算法,因为输入图像的大小不同,因此需要FCN。
并将AnimalsClassification示例的网络从此链接更改为UNet
但得到错误。你能帮我解决这个错误吗?
错误:
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
19:22:39,951 INFO ~ Load data....
19:22:40,696 INFO ~ Build model....
19:22:41,363 INFO ~ Loaded [CpuBackend] backend
19:22:44,955 INFO ~ Number of threads used for NativeOps: 2
19:22:45,555 INFO ~ Number of threads used for BLAS: 2
19:22:45,562 INFO ~ Backend used: [CPU]; OS: [Linux]
19:22:45,562 INFO ~ Cores: [2]; Memory: [1.3GB];
19:22:45,562 INFO ~ Blas vendor: [OPENBLAS]
19:22:56,425 WARN ~ Layer "Layer not named" distribution is set but will not be applied unless weight init is set to WeighInit.DISTRIBUTION.
Exception in thread "main" java.lang.IllegalStateException: Invalid configuration: network has no inputs. Use .addInputs(String...) to label (and give an ordering to) the network inputs
at org.deeplearning4j.nn.conf.ComputationGraphConfiguration.validate(ComputationGraphConfiguration.java:279)
at org.deeplearning4j.nn.conf.ComputationGraphConfiguration$GraphBuilder.build(ComputationGraphConfiguration.java:918)
at org.deeplearning4j.examples.convolution.AnimalsClassification.graphBuilder(AnimalsClassification.java:443)
at org.deeplearning4j.examples.convolution.AnimalsClassification.run(AnimalsClassification.java:145)
at org.deeplearning4j.examples.convolution.AnimalsClassification.main(AnimalsClassification.java:447)
Process finished with exit code 1
我更改的代码是:
package org.deeplearning4j.examples.convolution;
import lombok.Builder;
import org.apache.commons.io.FilenameUtils;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.TruncatedNormalDistribution;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.linalg.schedule.StepSchedule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.List;
import java.util.Random;
import static java.lang.Math.toIntExact;
/**
* Animal Classification
*
* Example classification of photos from 4 different animals (bear, duck, deer, turtle).
*
* References:
* - U.S. Fish and Wildlife Service (animal sample dataset): http://digitalmedia.fws.gov/cdm/
* - Tiny ImageNet Classification with CNN: http://cs231n.stanford.edu/reports/2015/pdfs/leonyao_final.pdf
*
* CHALLENGE: Current setup gets low score results. Can you improve the scores? Some approaches:
* - Add additional images to the dataset
* - Apply more transforms to dataset
* - Increase epochs
* - Try different model configurations
* - Tune by adjusting learning rate, updaters, activation & loss functions, regularization, ...
*/
public class AnimalsClassification {
protected static final Logger log = LoggerFactory.getLogger(AnimalsClassification.class);
protected static int height = 100;
protected static int width = 100;
protected static int channels = 3;
protected static int batchSize = 20;
// protected static long seed = 42;
private static long seed = 1234;
protected static Random rng = new Random(seed);
protected static int epochs = 50;
protected static double splitTrainTest = 0.8;
protected static boolean save = false;
protected static String modelType = "AlexNet"; // LeNet, AlexNet or Custom but you need to fill it out
private int numLabels;
public void run(String[] args) throws Exception {
log.info("Load data....");
/**cd
* Data Setup -> organize and limit data file paths:
* - mainPath = path to image files
* - fileSplit = define basic dataset split with limits on format
* - pathFilter = define additional file load filter to limit size and balance batch content
**/
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/");
FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
int numExamples = toIntExact(fileSplit.length());
numLabels = fileSplit.getRootDir().listFiles(File::isDirectory).length; //This only works if your root is clean: only label subdirs.
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);
/**
* Data Setup -> train test split
* - inputSplit = define train and test split
**/
InputSplit[] inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest);
InputSplit trainData = inputSplit[0];
InputSplit testData = inputSplit[1];
/**
* Data Setup -> transformation
* - Transform = how to tranform images and generate large dataset to train on
**/
// ImageTransform flipTransform1 = new FlipImageTransform(rng);
// ImageTransform flipTransform2 = new FlipImageTransform(new Random(123));
// ImageTransform warpTransform = new WarpImageTransform(rng, 42);
// ImageTransform colorTransform = new ColorConversionTransform(new Random(seed), COLOR_BGR2YCrCb);
// List<ImageTransform> transforms = Arrays.asList(new ImageTransform[]{flipTransform1, warpTransform, flipTransform2});
/**
* Data Setup -> normalization
* - how to normalize images and generate large dataset to train on
**/
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
log.info("Build model....");
// Uncomment below to try AlexNet. Note change height and width to at least 100
// MultiLayerNetwork network = new AlexNet(height, width, channels, numLabels, seed, iterations).init();
// MultiLayerNetwork network;
// switch (modelType) {
// case "LeNet":
// network = lenetModel();
// break;
// case "AlexNet":
// network = alexnetModel();
// break;
// case "custom":
// network = customModel();
// break;
// default:
// throw new InvalidInputTypeException("Incorrect model provided.");
// }
ComputationGraph network = new ComputationGraph(graphBuilder());
network.init();
// network.setListeners(new ScoreIterationListener(listenerFreq));
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
network.setListeners(new StatsListener( statsStorage),new ScoreIterationListener(1));
/**
* Data Setup -> define how to load data into net:
* - recordReader = the reader that loads and converts image data pass in inputSplit to initialize
* - dataIter = a generator that only loads one batch at a time into memory to save memory
* - trainIter = uses MultipleEpochsIterator to ensure model runs through the data for all epochs
**/
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
DataSetIterator dataIter;
MultipleEpochsIterator trainIter;
log.info("Train model....");
// Train without transformations
recordReader.initialize(trainData, null);
dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
scaler.fit(dataIter);
dataIter.setPreProcessor(scaler);
trainIter = new MultipleEpochsIterator(epochs, dataIter);
network.fit(trainIter);
// Train with transformations
/* for (ImageTransform transform : transforms) {
System.out.print("\nTraining on transformation: " + transform.getClass().toString() + "\n\n");
recordReader.initialize(trainData, transform);
dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
scaler.fit(dataIter);
dataIter.setPreProcessor(scaler);
trainIter = new MultipleEpochsIterator(epochs, dataIter);
network.fit(trainIter);
}*/
log.info("Evaluate model....");
recordReader.initialize(testData);
dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numLabels);
scaler.fit(dataIter);
dataIter.setPreProcessor(scaler);
Evaluation eval = network.evaluate(dataIter);
log.info(eval.stats(true));
// Example on how to get predict results with trained model. Result for first example in minibatch is printed
dataIter.reset();
DataSet testDataSet = dataIter.next();
List<String> allClassLabels = recordReader.getLabels();
int labelIndex = testDataSet.getLabels().argMax(1).getInt(0);
// int[] predictedClasses = network.predict(testDataSet.getFeatures());
String expectedResult = allClassLabels.get(labelIndex);
// String modelPrediction = allClassLabels.get(predictedClasses[0]);
// System.out.print("\nFor a single example that is labeled " + expectedResult + " the model predicted " + modelPrediction + "\n\n");
if (save) {
log.info("Save model....");
String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "src/main/resources/");
ModelSerializer.writeModel(network, basePath + "model.bin", true);
}
log.info("****************Example finished********************");
}
private ConvolutionLayer convInit(String name, int in, int out, int[] kernel, int[] stride, int[] pad, double bias) {
return new ConvolutionLayer.Builder(kernel, stride, pad).name(name).nIn(in).nOut(out).biasInit(bias).build();
}
private ConvolutionLayer conv3x3(String name, int out, double bias) {
return new ConvolutionLayer.Builder(new int[]{3,3}, new int[] {1,1}, new int[] {1,1}).name(name).nOut(out).biasInit(bias).build();
}
private ConvolutionLayer conv5x5(String name, int out, int[] stride, int[] pad, double bias) {
return new ConvolutionLayer.Builder(new int[]{5,5}, stride, pad).name(name).nOut(out).biasInit(bias).build();
}
private SubsamplingLayer maxPool(String name, int[] kernel) {
return new SubsamplingLayer.Builder(kernel, new int[]{2,2}).name(name).build();
}
private DenseLayer fullyConnected(String name, int out, double bias, double dropOut, Distribution dist) {
return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).dist(dist).build();
}
public MultiLayerNetwork lenetModel() {
/**
* Revisde Lenet Model approach developed by ramgo2 achieves slightly above random
* Reference: https://gist.github.com/ramgo2/833f12e92359a2da9e5c2fb6333351c5
**/
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.l2(0.005)
.activation(Activation.RELU)
.weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(0.0001,0.9))
.list()
.layer(0, convInit("cnn1", channels, 50 , new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0}, 0))
.layer(1, maxPool("maxpool1", new int[]{2,2}))
.layer(2, conv5x5("cnn2", 100, new int[]{5, 5}, new int[]{1, 1}, 0))
.layer(3, maxPool("maxool2", new int[]{2,2}))
.layer(4, new DenseLayer.Builder().nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(numLabels)
.activation(Activation.SOFTMAX)
.build())
.backprop(true).pretrain(false)
.setInputType(InputType.convolutional(height, width, channels))
.build();
return new MultiLayerNetwork(conf);
}
public MultiLayerNetwork alexnetModel() {
/**
* AlexNet model interpretation based on the original paper ImageNet Classification with Deep Convolutional Neural Networks
* and the imagenetExample code referenced.
* http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
**/
double nonZeroBias = 1;
double dropOut = 0.5;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.weightInit(WeightInit.DISTRIBUTION)
.dist(new NormalDistribution(0.0, 0.01))
.activation(Activation.RELU)
.updater(new Nesterovs(new StepSchedule(ScheduleType.ITERATION, 1e-2, 0.1, 100000), 0.9))
.biasUpdater(new Nesterovs(new StepSchedule(ScheduleType.ITERATION, 2e-2, 0.1, 100000), 0.9))
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) // normalize to prevent vanishing or exploding gradients
.l2(5 * 1e-4)
.list()
.layer(0, convInit("cnn1", channels, 96, new int[]{11, 11}, new int[]{4, 4}, new int[]{3, 3}, 0))
.layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
.layer(2, maxPool("maxpool1", new int[]{3,3}))
.layer(3, conv5x5("cnn2", 256, new int[] {1,1}, new int[] {2,2}, nonZeroBias))
.layer(4, new LocalResponseNormalization.Builder().name("lrn2").build())
.layer(5, maxPool("maxpool2", new int[]{3,3}))
.layer(6,conv3x3("cnn3", 384, 0))
.layer(7,conv3x3("cnn4", 384, nonZeroBias))
.layer(8,conv3x3("cnn5", 256, nonZeroBias))
.layer(9, maxPool("maxpool3", new int[]{3,3}))
.layer(10, fullyConnected("ffn1", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
.layer(11, fullyConnected("ffn2", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
.layer(12, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.name("output")
.nOut(numLabels)
.activation(Activation.SOFTMAX)
.build())
.backprop(true)
.pretrain(false)
.setInputType(InputType.convolutional(height, width, channels))
.build();
return new MultiLayerNetwork(conf);
}
public static MultiLayerNetwork customModel() {
/**
* Use this method to build your own custom model.
**/
return null;
}
@Builder.Default private int[] inputShape = new int[] {3, 512, 512};
@Builder.Default private int numClasses = 0;
@Builder.Default private WeightInit weightInit = WeightInit.RELU;
@Builder.Default private IUpdater updater = new AdaDelta();
@Builder.Default private CacheMode cacheMode = CacheMode.NONE;
// @Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
@Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.SINGLE;
@Builder.Default private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;
public ComputationGraphConfiguration graphBuilder() {
ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(updater)
.weightInit(weightInit)
.dist(new TruncatedNormalDistribution(0.0, 0.5))
.l2(5e-5)
.miniBatch(true)
.cacheMode(cacheMode)
.trainingWorkspaceMode(workspaceMode)
.inferenceWorkspaceMode(workspaceMode)
.graphBuilder();
graph
.addLayer("conv1-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "input")
.addLayer("conv1-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv1-1")
.addLayer("pool1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2)
.build(), "conv1-2")
.addLayer("conv2-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "pool1")
.addLayer("conv2-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv2-1")
.addLayer("pool2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2)
.build(), "conv2-2")
.addLayer("conv3-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "pool2")
.addLayer("conv3-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv3-1")
.addLayer("pool3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2)
.build(), "conv3-2")
.addLayer("conv4-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "pool3")
.addLayer("conv4-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv4-1")
.addLayer("drop4", new DropoutLayer.Builder(0.5).build(), "conv4-2")
.addLayer("pool4", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2)
.build(), "drop4")
.addLayer("conv5-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1024)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "pool4")
.addLayer("conv5-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1024)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv5-1")
.addLayer("drop5", new DropoutLayer.Builder(0.5).build(), "conv5-2")
// up6
.addLayer("up6-1", new Upsampling2D.Builder(2).build(), "drop5")
.addLayer("up6-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "up6-1")
.addVertex("merge6", new MergeVertex(), "drop4", "up6-2")
.addLayer("conv6-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "merge6")
.addLayer("conv6-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv6-1")
// up7
.addLayer("up7-1", new Upsampling2D.Builder(2).build(), "conv6-2")
.addLayer("up7-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "up7-1")
.addVertex("merge7", new MergeVertex(), "conv3-2", "up7-2")
.addLayer("conv7-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "merge7")
.addLayer("conv7-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv7-1")
// up8
.addLayer("up8-1", new Upsampling2D.Builder(2).build(), "conv7-2")
.addLayer("up8-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "up8-1")
.addVertex("merge8", new MergeVertex(), "conv2-2", "up8-2")
.addLayer("conv8-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "merge8")
.addLayer("conv8-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv8-1")
// up9
.addLayer("up9-1", new Upsampling2D.Builder(2).build(), "conv8-2")
.addLayer("up9-2", new ConvolutionLayer.Builder(2,2).stride(1,1).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "up8-1")
.addVertex("merge9", new MergeVertex(), "conv1-2", "up9-2")
.addLayer("conv9-1", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "merge9")
.addLayer("conv9-2", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv9-1")
.addLayer("conv9-3", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(2)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), "conv9-2")
.addLayer("conv10", new ConvolutionLayer.Builder(3,3).stride(1,1).nOut(1)
.convolutionMode(ConvolutionMode.Truncate).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.SIGMOID).build(), "conv9-3")
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.MCXENT).build(), "conv10")
.setOutputs("output").backprop(true).pretrain(false);
return graph.build();
}
public static void main(String[] args) throws Exception {
new AnimalsClassification().run(args);
}
}
非常感谢。