Initial utils implementation + bug fixes by apaszke · Pull Request #12 · pytorch/pytorch · GitHub

Initial utils implementation + bug fixes #12


Merged: 3 commits on Sep 9, 2016
Changes from all commits
4 changes: 4 additions & 0 deletions test/common_nn.py
@@ -380,6 +380,10 @@ def __call__(self, test_case):
module = self.constructor(*self.constructor_args)
input = self._get_input()

# Check that these methods don't raise errors
module.__repr__()
str(module)

if self.reference_fn is not None:
out = test_case._forward_criterion(module, input, self.target)
expected_out = self.reference_fn(deepcopy(self._unpack_input(input)),
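For context on the pattern this hunk introduces (and which repeats throughout the PR): the two calls are pure smoke checks, relying on the test runner to surface any exception raised by the printing code. Both module.__repr__() and str(module) are exercised because str() dispatches to __str__ when a class defines one and falls back to __repr__ otherwise, so the two calls can take different code paths. A minimal standalone sketch of the idea (the SimpleModule class below is hypothetical, not part of this PR):

    import unittest

    class SimpleModule(object):
        """Hypothetical stand-in for an nn module."""

        def __repr__(self):
            return 'SimpleModule()'

        def __str__(self):
            return 'SimpleModule with no trainable parameters'

    class TestPrinting(unittest.TestCase):
        def test_printing_does_not_raise(self):
            module = SimpleModule()
            # Both calls must complete without raising; any exception
            # in the formatting code fails the test.
            module.__repr__()
            str(module)

    if __name__ == '__main__':
        unittest.main()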
1 change: 1 addition & 0 deletions test/run_test.sh
@@ -8,6 +8,7 @@ python test_autograd.py
python test_nn.py
python test_legacy_nn.py
python test_multiprocessing.py
python test_utils.py
if which nvcc >/dev/null 2>&1
then
python test_cuda.py
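The `which nvcc >/dev/null 2>&1` guard above means test_cuda.py runs only when the CUDA compiler is found on the PATH, while the newly added test_utils.py entry always runs. For illustration, a rough Python equivalent of that shell guard (a sketch, not part of this PR; shutil.which requires Python 3.3+):

    import shutil
    import subprocess

    # Run the CUDA suite only when the CUDA compiler is on the PATH,
    # mirroring the `which nvcc >/dev/null 2>&1` check in run_test.sh.
    if shutil.which('nvcc') is not None:
        subprocess.check_call(['python', 'test_cuda.py'])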
92 changes: 78 additions & 14 deletions test/test_legacy_nn.py
@@ -152,12 +152,11 @@ def _do_test(self, test_case, module, input):
OldModuleTest(nn.Squeeze,
input_size=(2, 1, 1, 4, 5),
reference_fn=lambda i,_: i.squeeze()),
# TODO: should squeeze work inplace?
# OldModuleTest(nn.Squeeze,
# (1,),
# input_size=(2, 1, 1, 4, 5),
# reference_fn=lambda i,_: i.squeeze(1),
# desc='dim'),
OldModuleTest(nn.Squeeze,
(1,),
input_size=(2, 1, 1, 4, 5),
reference_fn=lambda i,_: i.squeeze(1),
desc='dim'),
OldModuleTest(nn.Unsqueeze,
(1,),
input_size=(2, 4, 5),
@@ -356,13 +355,12 @@ def _do_test(self, test_case, module, input):
desc='stride_pad'),
OldModuleTest(nn.SpatialAdaptiveMaxPooling,
(4, 4),
input_size=(2, 3, 8, 8)),
input_size=(2, 3, 8, 8),
reference_fn=lambda i,_: nn.SpatialMaxPooling(2, 2).forward(i)),
OldModuleTest(nn.SpatialAdaptiveMaxPooling,
(4, 4),
input_size=(2, 3, 7, 11),
desc='irregular'),
# TODO: enable after implementing MaxPooling
# reference_fn=lambda i,_: nn.SpatialMaxPooling(2, 2).forward(i)),
OldModuleTest(nn.SpatialConvolution,
(3, 4, 3, 3),
input_size=(2, 3, 6, 6)),
@@ -385,10 +383,9 @@ def _do_test(self, test_case, module, input):
(3, 2, 6, 6, 2, 2, 2, 2, 1, 1),
input_size=(2, 3, 6, 6),
desc='stride_pad'),
# TODO FIX THIS
# OldModuleTest(nn.SpatialCrossMapLRN,
# (3,),
# input_size=(2, 3, 6, 6)),
OldModuleTest(nn.SpatialCrossMapLRN,
(5, 5e-3, 1e-3, 2),
input_size=(2, 3, 6, 6)),
OldModuleTest(nn.SpatialDivisiveNormalization,
(3,),
input_size=(2, 3, 8, 8)),
@@ -750,6 +747,10 @@ def test_Dropout(self):
gradInput = module.backward(input.clone(), input.clone())
self.assertLess(abs(gradInput.mean() - (1-p)), 0.05)

# Check that these don't raise errors
module.__repr__()
str(module)

def test_SpatialDropout(self):
p = 0.2
b = random.randint(1, 5)
@@ -764,6 +765,10 @@ def test_SpatialDropout(self):
gradInput = module.backward(input, input)
self.assertLess(abs(gradInput.mean() - (1-p)), 0.05)

# Check that these don't raise errors
module.__repr__()
str(module)

def test_VolumetricDropout(self):
p = 0.2
bsz = random.randint(1,5)
@@ -779,6 +784,10 @@ def test_VolumetricDropout(self):
gradInput = module.backward(input, input)
self.assertLess(abs(gradInput.mean() - (1-p)), 0.05)

# Check that these don't raise errors
module.__repr__()
str(module)

def test_ReLU_reference(self):
input = torch.randn(10, 20)
module = nn.ReLU()
@@ -790,7 +799,6 @@ def test_ReLU6_reference(self):
input = torch.randn(10, 20).mul(10)
module = nn.ReLU6()
output = module.forward(input)
# TODO: check elements between 0 and 6
self.assertTrue(output[input.ge(6)].eq(6).all())
self.assertTrue(output[input.lt(0)].eq(0).all())

@@ -807,6 +815,10 @@ def test_Copy(self):
c.double()
self.assertEqual(torch.typename(output), 'torch.FloatTensor')

# Check that these don't raise errors
c.__repr__()
str(c)

def test_FlattenTable(self):
input = [
torch.rand(1),
@@ -839,6 +851,10 @@ def test_FlattenTable(self):
self.assertEqual(gradOutput[2], gradInput[1][1][0])
self.assertEqual(gradOutput[3], gradInput[2])

# Check that these don't raise errors
m.__repr__()
str(m)

# More ugliness: FlattenTable doesn't rebuild the table on every updateOutput
# call, so we need to make sure that modifications to the input are
# detected correctly (and that the table is correctly rebuilt).
@@ -877,6 +893,10 @@ def test_Concat(self):
l.weight.fill_(1)
l.bias.fill_(0)

# Check that these don't raise errors
m.__repr__()
str(m)

output = m.forward(input)
output2 = input.sum(1).expand(4, 5).repeatTensor(num_modules, 1)
self.assertEqual(output2, output)
@@ -896,6 +916,10 @@ def test_Parallel(self):
m.add(nn.View(4, 5, 1))
m.add(nn.View(4, 5, 1))

# Check that these don't raise errors
m.__repr__()
str(m)

output = m.forward(input)
output2 = input.transpose(0, 2).transpose(0, 1)
self.assertEqual(output2, output)
@@ -914,6 +938,10 @@ def test_ParallelTable(self):
m.add(p)
m.add(nn.JoinTable(2))

# Check that these don't raise errors
p.__repr__()
str(p)

output = m.forward(input)
output2 = input.transpose(0,2).transpose(0,1)
self.assertEqual(output2, output)
@@ -939,6 +967,10 @@ def test_ConcatTable(self):
module.add(nn.Identity())
module.float()

# Check that these don't raise errors
module.__repr__()
str(module)

output = module.forward(input)
output2 = [input, input, input]
self.assertEqual(output2, output)
@@ -998,6 +1030,10 @@ def test_DepthConcat(self):
self.assertEqual(output, outputConcat)
self.assertEqual(gradInput, gradInputConcat)

# Check that these don't raise errors
concat.__repr__()
str(concat)

def test_Contiguous(self):
input = torch.randn(10, 10, 10)
noncontig = input[:, 4]
@@ -1007,6 +1043,10 @@ def test_Contiguous(self):
self.assertEqual(output, noncontig)
self.assertTrue(output.contiguous())

# Check that these don't raise errors
module.__repr__()
str(module)

def test_Index(self):
net = nn.Index(0)

@@ -1028,6 +1068,10 @@ def test_Index(self):
gradInput = net.backward(input, gradOutput)
self.assertEqual(gradInput[0], torch.Tensor(((2, 4), (0, 0))))

# Check that these don't raise errors
net.__repr__()
str(net)

def test_L1Penalty(self):
weight = 1
m = nn.L1Penalty(weight, False, False)
@@ -1044,6 +1088,10 @@ def test_L1Penalty(self):
input.lt(0).typeAs(grad).mul_(-1)).mul_(weight)
self.assertEqual(true_grad, grad)

# Check that these don't raise errors
m.__repr__()
str(m)

def test_MaskedSelect(self):
input = torch.randn(4, 5)
mask = torch.ByteTensor(4, 5).bernoulli_()
@@ -1060,6 +1108,10 @@ def test_MaskedSelect(self):
gradIn = module.backward([input, mask], gradOut)
self.assertEqual(inTarget, gradIn[0])

# Check that these don't raise errors
module.__repr__()
str(module)

def test_MultiCriterion(self):
input = torch.rand(2, 10)
target = torch.Tensor((1, 8))
@@ -1085,6 +1137,10 @@ def test_MultiCriterion(self):
self.assertEqual(output, output3)
self.assertEqual(gradInput.float(), gradInput3)

# Check that these don't raise errors
mc.__repr__()
str(mc)

# test table input
# TODO: enable when Criterion.clone is ready
# mc.double()
@@ -1164,6 +1220,10 @@ def test_ParallelCriterion(self):
self.assertEqual(gradInput[1][0], gradInput2[1][0])
self.assertEqual(gradInput[1][1], gradInput2[1][1])

# Check that these don't raise errors
pc.__repr__()
str(pc)

def test_NarrowTable(self):
input = [torch.Tensor(i) for i in range(1, 6)]

@@ -1175,6 +1235,10 @@ def test_NarrowTable(self):
output = module.forward(input)
self.assertEqual(output, input[2:5])

# Check that these don't raise errors
module.__repr__()
str(module)


if __name__ == '__main__':
prepare_tests()
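A note on the SpatialAdaptiveMaxPooling change above: the first test case gains reference_fn=lambda i,_: nn.SpatialMaxPooling(2, 2).forward(i). That reference is exact because an 8x8 input adaptively pooled to a 4x4 output assigns every output cell exactly one 2x2 window, so fixed 2x2 stride-2 max pooling computes the same result; the 7x11 'irregular' case has no such fixed-window equivalent, which is presumably why it carries no reference_fn. A NumPy sketch of that equivalence (both helpers are illustrative, not part of this PR, and assume the usual floor/ceil window formula for adaptive pooling):

    import numpy as np

    def max_pool_2x2(x):
        # Fixed 2x2, stride-2 max pooling over the last two dimensions.
        h, w = x.shape[-2] // 2, x.shape[-1] // 2
        return x.reshape(x.shape[:-2] + (h, 2, w, 2)).max(axis=-1).max(axis=-2)

    def adaptive_max_pool(x, out_h, out_w):
        # Output cell (i, j) covers rows [floor(i*H/out_h), ceil((i+1)*H/out_h))
        # and the analogous range of columns.
        H, W = x.shape[-2], x.shape[-1]
        out = np.empty(x.shape[:-2] + (out_h, out_w), dtype=x.dtype)
        for i in range(out_h):
            r0, r1 = i * H // out_h, -((-(i + 1) * H) // out_h)
            for j in range(out_w):
                c0, c1 = j * W // out_w, -((-(j + 1) * W) // out_w)
                out[..., i, j] = x[..., r0:r1, c0:c1].max(axis=(-2, -1))
        return out

    # For an 8x8 input pooled to 4x4, the two agree exactly.
    x = np.random.randn(2, 3, 8, 8)
    assert np.allclose(adaptive_max_pool(x, 4, 4), max_pool_2x2(x))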
14 changes: 13 additions & 1 deletion test/test_torch.py
@@ -4,7 +4,7 @@
import torch
import tempfile
import unittest
from itertools import product
from itertools import product, chain
from common import TestCase, iter_indices

SIZE = 100
@@ -2252,5 +2252,17 @@ def test_from_buffer(self):
self.assertEqual(floats.size(), 1)
self.assertEqual(floats[0], 2.25)

def test_print(self):
for t in torch._tensor_classes:
obj = t(100, 100).fill_(1)
obj.__repr__()
str(obj)
for t in torch._storage_classes:
obj = t(100).fill_(1)
obj.__repr__()
str(obj)


if __name__ == '__main__':
unittest.main()
