Predict new examples by a trained neural net.

# S3 method for class 'nnet'
predict(object, newdata, type = c("raw","class"), ...)
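In the usage line, `type = c("raw","class")` lists the allowed values, and leaving `type` out should give the "raw" output. The minimal sketch below checks that by comparing a default call with an explicit `type = "raw"` call. It assumes a small net fitted to the built-in `iris` data frame. The seed, `size`, `maxit`, `trace = FALSE` and the names `fit`, `p_default`, `p_raw` are illustrative choices and are not part of this method's interface.

```r
library(nnet)

set.seed(42)  # nnet() starts from random weights
X <- as.matrix(iris[, 1:4])        # four numeric inputs
Y <- class.ind(iris$Species)       # one indicator column per species
fit <- nnet(X, Y, size = 2, rang = 0.1, decay = 5e-4,
            maxit = 100, trace = FALSE)

# Omitting 'type' should select "raw", the first entry of c("raw", "class")
p_default <- predict(fit, X[1:5, ])
p_raw     <- predict(fit, X[1:5, ], type = "raw")
identical(p_default, p_raw)   # TRUE
```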

Arguments

object

an object of class nnet as returned by nnet.

newdata

matrix or data frame of test examples. A vector is considered to be a row vector comprising a single case.

type

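Type of output: "raw" (the default) gives the matrix of values computed by the network; "class" gives the predicted class of each case (see Value).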

...

arguments passed to or from other methods.
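The `newdata` entry says that a plain vector is read as one row, meaning a single case. The sketch below shows this by predicting one flower twice, once from a numeric vector and once from a one-row matrix, and then comparing the results. It assumes a net fitted to the built-in `iris` data frame. The seed, the fitting settings and the names `x1`, `p_vec`, `p_row` are illustrative only.

```r
library(nnet)

set.seed(1)  # nnet() draws random starting weights
X <- as.matrix(iris[, 1:4])   # iris measurements as a numeric matrix
fit <- nnet(X, class.ind(iris$Species), size = 2, rang = 0.1,
            decay = 5e-4, maxit = 100, trace = FALSE)

x1 <- X[1, ]                      # a plain numeric vector of 4 values
p_vec <- predict(fit, x1)         # treated as a single case (one row)
p_row <- predict(fit, X[1, , drop = FALSE])   # the same case as a 1-row matrix

dim(p_vec)                                      # 1 3
all.equal(as.vector(p_vec), as.vector(p_row))   # TRUE
```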

Value

If type = "raw", the matrix of values returned by the trained network; if type = "class", the corresponding class (which is probably only useful if the net was generated by nnet.formula).
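The following sketch compares the two kinds of output for a net fitted through the formula interface, which is the case the Value section singles out for `type = "class"`. It assumes the built-in `iris` data frame with `Species` as the factor response. The seed, the settings, and the names `fit_f`, `newd`, `raw`, `cls` are illustrative. For a three-level factor you should get one output column per level, and the class labels should be drawn from the factor's levels.

```r
library(nnet)

set.seed(7)  # random starting weights
fit_f <- nnet(Species ~ ., data = iris, size = 2, rang = 0.1,
              decay = 5e-4, maxit = 100, trace = FALSE)

newd <- iris[c(1, 51, 101), ]                  # one flower per species
raw  <- predict(fit_f, newd)                   # numeric matrix of network outputs
cls  <- predict(fit_f, newd, type = "class")   # predicted species labels

dim(raw)   # 3 rows (cases) x 3 columns (one per output unit)
cls
```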

Details

This function is a method for the generic function predict() for class "nnet". It can be invoked by calling predict(x) for an object x of the appropriate class, or directly by calling predict.nnet(x) regardless of the class of the object.
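The sketch below shows the dispatch described above. It fetches the registered S3 method and checks that calling it directly gives the same result as the generic `predict()`. It uses `utils::getS3method()` instead of the bare name `predict.nnet`, because whether that name is exported from the package namespace may depend on the nnet version (this is an assumption). The fitted model and the names `predict_nnet`, `via_generic`, `via_method` are illustrative.

```r
library(nnet)

set.seed(3)
X <- as.matrix(iris[, 1:4])
fit <- nnet(X, class.ind(iris$Species), size = 2, rang = 0.1,
            decay = 5e-4, maxit = 100, trace = FALSE)

inherits(fit, "nnet")   # TRUE, so predict() dispatches to the nnet method

# Look up the registered S3 method explicitly
predict_nnet <- utils::getS3method("predict", "nnet")

via_generic <- predict(fit, X[1:3, ])
via_method  <- predict_nnet(fit, X[1:3, ])
identical(via_generic, via_method)   # TRUE
```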

Examples

library(nnet)

# train on a random half of the iris data (25 flowers per species)
ir <- rbind(iris3[,,1], iris3[,,2], iris3[,,3])
# one indicator column per class; class.ind() orders the columns c, s, v
targets <- class.ind( c(rep("s", 50), rep("c", 50), rep("v", 50)) )
samp <- c(sample(1:50,25), sample(51:100,25), sample(101:150,25))
ir1 <- nnet(ir[samp,], targets[samp,], size = 2, rang = 0.1,
            decay = 5e-4, maxit = 200)
#> # weights:  19
#> initial  value 55.872771 
#> iter  10 value 44.559876
#> iter  20 value 25.625047
#> iter  30 value 3.372670
#> iter  40 value 2.585274
#> iter  50 value 2.530046
#> iter  60 value 2.477312
#> iter  70 value 2.467636
#> iter  80 value 2.459026
#> iter  90 value 2.451992
#> iter 100 value 2.450624
#> iter 110 value 2.449916
#> iter 120 value 2.449655
#> iter 130 value 2.449526
#> final  value 2.449515 
#> converged
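# confusion table of true versus predicted class index
test.cl <- function(true, pred){
    true <- max.col(true)   # column holding the 1 in each indicator row
    cres <- max.col(pred)   # column holding the largest network output
    table(true, cres)
}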
test.cl(targets[-samp,], predict(ir1, ir[-samp,]))
#>     cres
#> true  1  2  3
#>    1 23  0  2
#>    2  0 25  0
#>    3  0  0 25

# or, equivalently, with the formula interface and a factor response
ird <- data.frame(rbind(iris3[,,1], iris3[,,2], iris3[,,3]),
        species = factor(c(rep("s",50), rep("c", 50), rep("v", 50))))
ir.nn2 <- nnet(species ~ ., data = ird, subset = samp, size = 2, rang = 0.1,
               decay = 5e-4, maxit = 200)
#> # weights:  19
#> initial  value 82.804858 
#> iter  10 value 35.075341
#> iter  20 value 34.896083
#> iter  30 value 28.106085
#> iter  40 value 6.010230
#> iter  50 value 4.493900
#> iter  60 value 4.293801
#> iter  70 value 4.258495
#> iter  80 value 4.219140
#> iter  90 value 4.217768
#> iter 100 value 4.214896
#> iter 110 value 4.172050
#> iter 120 value 3.281227
#> iter 130 value 2.447272
#> iter 140 value 2.013296
#> iter 150 value 1.671610
#> iter 160 value 1.339102
#> iter 170 value 0.999793
#> iter 180 value 0.875408
#> iter 190 value 0.831225
#> iter 200 value 0.804169
#> final  value 0.804169 
#> stopped after 200 iterations
table(ird$species[-samp], predict(ir.nn2, ird[-samp,], type = "class"))
#>    
#>      c  s  v
#>   c 24  0  1
#>   s  0 25  0
#>   v  3  0 22
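Both examples above end with a confusion table. The sketch below reduces such a table to a single accuracy figure and shows how `max.col()` turns rows into class indices inside `test.cl()`. The helper `accuracy()` is hypothetical and not part of nnet. The two tables are typed in by hand from the particular run printed above. Because `samp` is drawn at random, a fresh run will give different counts.

```r
# Hypothetical helper: proportion of cases on the diagonal of a square
# confusion table whose rows and columns list the classes in the same order
accuracy <- function(tab) sum(diag(tab)) / sum(tab)

# max.col() returns, for each row, the column holding its largest value;
# test.cl() uses this on both the indicator targets and the network outputs
max.col(rbind(c(0, 1, 0), c(0.7, 0.2, 0.1)))   # 2 1

# The two confusion tables printed above, entered by hand
tab1 <- matrix(c(23,  0,  2,
                  0, 25,  0,
                  0,  0, 25), nrow = 3, byrow = TRUE)
tab2 <- matrix(c(24,  0,  1,
                  0, 25,  0,
                  3,  0, 22), nrow = 3, byrow = TRUE)

accuracy(tab1)   # 73/75, about 0.973
accuracy(tab2)   # 71/75, about 0.947
```

The diagonal is only meaningful when the table is square and both margins list the classes in the same order. `table()` omits any predicted class that never occurs, so a sketch like this could wrap both arguments in `factor(..., levels = 1:3)` before tabulating.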