Add Storage.from_buffer by colesbury · Pull Request #9 · pytorch/pytorch · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add Storage.from_buffer #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 7, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def run(self):

include_dirs = []
extra_link_args = []
extra_compile_args = ['-std=c++11']
extra_compile_args = ['-std=c++11', '-Wno-write-strings']

cwd = os.path.dirname(os.path.abspath(__file__))
lib_path = os.path.join(cwd, "torch", "lib")
Expand All @@ -138,6 +138,7 @@ def run(self):
"torch/csrc/Generator.cpp",
"torch/csrc/Tensor.cpp",
"torch/csrc/Storage.cpp",
"torch/csrc/byte_order.cpp",
"torch/csrc/utils.cpp",
"torch/csrc/allocators.cpp",
"torch/csrc/serialization.cpp",
Expand Down
14 changes: 14 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2238,5 +2238,19 @@ def test_serialization(self):
c[1].fill_(20)
self.assertEqual(c[1], c[3], 0)

def test_from_buffer(self):
    # 1-byte elements are copied verbatim; byte_order may be omitted.
    a = bytearray([1, 2, 3, 4])
    self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
    # 16-bit big-endian: 0x0102 == 258, 0x0304 == 772.
    shorts = torch.ShortStorage.from_buffer(a, 'big')
    self.assertEqual(shorts.size(), 2)
    self.assertEqual(shorts.tolist(), [258, 772])
    # count/offset select a sub-range; offset is measured in bytes.
    shorts = torch.ShortStorage.from_buffer(a, 'big', count=1, offset=2)
    self.assertEqual(shorts.tolist(), [772])
    # 32-bit little-endian: 0x04030201 == 67305985.
    ints = torch.IntStorage.from_buffer(a, 'little')
    self.assertEqual(ints.size(), 1)
    self.assertEqual(ints[0], 67305985)
    # A set top bit must decode as a negative two's-complement value.
    ints = torch.IntStorage.from_buffer(bytearray([0xff] * 4), 'little')
    self.assertEqual(ints[0], -1)
    # 64-bit integers in both byte orders; distinct bytes in every position
    # catch a decoder that reads the wrong source byte.
    b = bytearray([1, 2, 3, 4, 5, 6, 7, 8])
    longs = torch.LongStorage.from_buffer(b, 'big')
    self.assertEqual(longs.size(), 1)
    self.assertEqual(longs[0], 0x0102030405060708)
    longs = torch.LongStorage.from_buffer(b, 'little')
    self.assertEqual(longs[0], 0x0807060504030201)
    # IEEE-754 single precision: 0x40100000 == 2.25.
    f = bytearray([0x40, 0x10, 0x00, 0x00])
    floats = torch.FloatStorage.from_buffer(f, 'big')
    self.assertEqual(floats.size(), 1)
    self.assertEqual(floats[0], 2.25)
    # IEEE-754 double precision: 0x4002000000000000 == 2.25.
    d = bytearray([0x40, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
    doubles = torch.DoubleStorage.from_buffer(d, 'big')
    self.assertEqual(doubles.size(), 1)
    self.assertEqual(doubles[0], 2.25)
    doubles = torch.DoubleStorage.from_buffer(bytearray(reversed(d)), 'little')
    self.assertEqual(doubles[0], 2.25)

if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()
2 changes: 1 addition & 1 deletion torch/csrc/Storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
#include <TH/TH.h>
#include <libshm.h>
#include "THP.h"
#include "byte_order.h"

#include "generic/Storage.cpp"
#include <TH/THGenerateAllTypes.h>

#include "generic/StorageCopy.cpp"
#include <TH/THGenerateAllTypes.h>

81 changes: 81 additions & 0 deletions torch/csrc/byte_order.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#include "byte_order.h"

#include <string.h>

// Decodes two bytes stored least-significant-byte first.
static inline uint16_t decodeUInt16LE(const uint8_t *data) {
  uint16_t low = data[0];
  uint16_t high = data[1];
  return (uint16_t)(low | (high << 8));
}

// Decodes two bytes stored most-significant-byte first.
static inline uint16_t decodeUInt16BE(const uint8_t *data) {
  uint16_t high = data[0];
  uint16_t low = data[1];
  return (uint16_t)((high << 8) | low);
}

// Decodes four bytes stored least-significant-byte first.
static inline uint32_t decodeUInt32LE(const uint8_t *data) {
  // Widen each byte to uint32_t before shifting: a bare uint8_t promotes to
  // signed int, and shifting a byte >= 0x80 left by 24 overflows int (UB).
  return ((uint32_t)data[0] << 0) | ((uint32_t)data[1] << 8) |
         ((uint32_t)data[2] << 16) | ((uint32_t)data[3] << 24);
}

// Decodes four bytes stored most-significant-byte first.
static inline uint32_t decodeUInt32BE(const uint8_t *data) {
  // Widen each byte to uint32_t before shifting: a bare uint8_t promotes to
  // signed int, and shifting a byte >= 0x80 left by 24 overflows int (UB).
  return ((uint32_t)data[3] << 0) | ((uint32_t)data[2] << 8) |
         ((uint32_t)data[1] << 16) | ((uint32_t)data[0] << 24);
}

// Decodes eight bytes stored least-significant-byte first.
static inline uint64_t decodeUInt64LE(const uint8_t *data) {
  // Fold from the most significant byte (data[7]) down to data[0].
  uint64_t value = 0;
  for (int pos = 7; pos >= 0; --pos) {
    value = (value << 8) | (uint64_t)data[pos];
  }
  return value;
}

// Decodes eight bytes stored most-significant-byte first: data[0] becomes
// bits 56..63 and data[7] becomes bits 0..7.
static inline uint64_t decodeUInt64BE(const uint8_t *data) {
  // Each source byte must appear exactly once; previously data[7] was used
  // for both the <<0 and <<8 positions and data[6] was dropped.
  return (((uint64_t)data[7])<< 0) | (((uint64_t)data[6])<< 8) |
         (((uint64_t)data[5])<<16) | (((uint64_t)data[4])<<24) |
         (((uint64_t)data[3])<<32) | (((uint64_t)data[2])<<40) |
         (((uint64_t)data[1])<<48) | (((uint64_t)data[0])<<56);
}

// Reports the byte order of the machine running this code.
THPByteOrder THP_nativeByteOrder()
{
  // The lowest-addressed byte of the integer 1 is non-zero only when the
  // least-significant byte is stored first.
  const uint32_t probe = 1;
  const uint8_t first_byte = *(const uint8_t*)&probe;
  if (first_byte != 0) {
    return THP_LITTLE_ENDIAN;
  }
  return THP_BIG_ENDIAN;
}

// Decodes `len` 16-bit signed integers, each stored in 2 bytes of `src` in
// the given byte order, into dst[0..len-1].
void THP_decodeInt16Buffer(int16_t* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
  // Select the element decoder once rather than re-testing `order` per element.
  uint16_t (*decode)(const uint8_t *) =
      (order == THP_BIG_ENDIAN) ? decodeUInt16BE : decodeUInt16LE;
  for (size_t idx = 0; idx < len; ++idx) {
    dst[idx] = (int16_t) decode(src + idx * sizeof(int16_t));
  }
}

// Decodes `len` 32-bit signed integers, each stored in 4 bytes of `src` in
// the given byte order, into dst[0..len-1].
void THP_decodeInt32Buffer(int32_t* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
  const bool big_endian = (order == THP_BIG_ENDIAN);
  const uint8_t *cursor = src;
  int32_t *out = dst;
  size_t remaining = len;
  while (remaining > 0) {
    const uint32_t raw = big_endian ? decodeUInt32BE(cursor) : decodeUInt32LE(cursor);
    *out++ = (int32_t) raw;
    cursor += sizeof(int32_t);
    --remaining;
  }
}

// Decodes `len` 64-bit signed integers, each stored in 8 bytes of `src` in
// the given byte order, into dst[0..len-1].
void THP_decodeInt64Buffer(int64_t* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
  const bool is_big = (order == THP_BIG_ENDIAN);
  for (size_t k = 0; k < len; ++k) {
    const uint8_t *elem = src + k * sizeof(int64_t);
    dst[k] = (int64_t) (is_big ? decodeUInt64BE(elem) : decodeUInt64LE(elem));
  }
}

// Decodes `len` 32-bit floats, each stored as 4 bytes of `src` in the given
// byte order, into dst[0..len-1].
void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
  static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");
  for (size_t i = 0; i < len; i++) {
    uint32_t bits = (order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src));
    // memcpy is the well-defined way to reinterpret the bit pattern; reading
    // a union member other than the one last written is UB in C++.
    memcpy(&dst[i], &bits, sizeof(float));
    src += sizeof(float);
  }
}

// Decodes `len` 64-bit doubles, each stored as 8 bytes of `src` in the given
// byte order, into dst[0..len-1].
void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
  static_assert(sizeof(double) == sizeof(uint64_t), "double must be 64 bits wide");
  for (size_t i = 0; i < len; i++) {
    uint64_t bits = (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src));
    // memcpy is the well-defined way to reinterpret the bit pattern; reading
    // a union member other than the one last written is UB in C++.
    memcpy(&dst[i], &bits, sizeof(double));
    src += sizeof(double);
  }
}
14 changes: 14 additions & 0 deletions torch/csrc/byte_order.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef THP_BYTE_ORDER_H
#define THP_BYTE_ORDER_H

// Helpers that decode raw byte buffers of a known byte order into native
// integer and floating-point arrays.

#include <stdint.h>
#include <stddef.h>

// Byte order of an encoded buffer.
enum THPByteOrder {
  THP_LITTLE_ENDIAN = 0,
  THP_BIG_ENDIAN = 1
};

// Returns the byte order of the machine running this code.
THPByteOrder THP_nativeByteOrder();

// Each THP_decode*Buffer reads `len` elements from `src`, where every element
// occupies sizeof(element type) consecutive bytes encoded in `order`, and
// writes the decoded values to dst[0..len-1]. `src` must provide at least
// len * sizeof(element type) bytes.
void THP_decodeInt16Buffer(int16_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeInt32Buffer(int32_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeInt64Buffer(int64_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len);

#endif // THP_BYTE_ORDER_H
85 changes: 85 additions & 0 deletions torch/csrc/generic/StorageMethods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,90 @@ static PyObject * THPStorage_(fill_)(THPStorage *self, PyObject *number_arg)
END_HANDLE_TH_ERRORS
}

#ifndef THC_GENERIC_FILE
// Storage.from_buffer(buffer, byte_order, count=-1, offset=0)
//
// Creates a new storage holding a copy of `count` elements read from any
// object that exposes the buffer protocol, starting `offset` bytes into it.
//   buffer:     object accepted by PyObject_GetBuffer (e.g. bytes, bytearray)
//   byte_order: 'big', 'little' or 'native'; required for multi-byte element
//               types, optional and ignored for Byte/Char storages
//   count:      number of elements to read; a negative value (the default)
//               means "all remaining bytes", which must then be a whole
//               number of elements
//   offset:     byte offset of the first element, 0 <= offset <= len(buffer)
// Raises ValueError for an unknown byte_order, an out-of-range offset, or a
// buffer too small for the requested count.
static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyObject *keywds)
{
  HANDLE_TH_ERRORS
  PyObject *obj = NULL;
  const char* byte_order_str = NULL;
  Py_ssize_t count = -1, offset = 0;
  Py_buffer buffer;
  static char *kwlist[] = {"buffer", "byte_order", "count", "offset", NULL};
  const char* argtypes;
#if defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR)
  // Single-byte elements have no byte order, so byte_order is optional.
  argtypes = "O|snn";
#else
  argtypes = "Os|nn";
#endif

  if (!PyArg_ParseTupleAndKeywords(args, keywds, argtypes, kwlist,
        &obj, &byte_order_str, &count, &offset)) {
    return NULL;
  }

#if !(defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR))
  THPByteOrder byte_order;
  if (strcmp(byte_order_str, "native") == 0) {
    byte_order = THP_nativeByteOrder();
  } else if (strcmp(byte_order_str, "big") == 0) {
    byte_order = THP_BIG_ENDIAN;
  } else if (strcmp(byte_order_str, "little") == 0) {
    byte_order = THP_LITTLE_ENDIAN;
  } else {
    PyErr_Format(PyExc_ValueError,
      "invalid byte_order '%s' (expected 'big', 'little', or 'native')",
      byte_order_str);
    return NULL;
  }
#endif

  if (PyObject_GetBuffer(obj, &buffer, PyBUF_SIMPLE) < 0)
    return NULL;

  // Every successful PyObject_GetBuffer must be paired with PyBuffer_Release
  // on all exit paths, including a TH error thrown by newWithSize below;
  // otherwise the export leaks and the exporter stays locked (e.g. a
  // bytearray can no longer be resized).
  struct BufferGuard {
    Py_buffer *view;
    ~BufferGuard() { PyBuffer_Release(view); }
  } buffer_guard = {&buffer};

  if (offset < 0 || offset > buffer.len) {
    PyErr_Format(PyExc_ValueError,
      "offset must be non-negative and no greater than buffer length (%ld)",
      (long) buffer.len);
    return NULL;
  }

  const Py_ssize_t elem_size = (Py_ssize_t) sizeof(real);
  const Py_ssize_t available = buffer.len - offset;  // bytes after offset

  if (count < 0) {
    if (available % elem_size != 0) {
      PyErr_Format(PyExc_ValueError, "buffer size must be a multiple of element size");
      return NULL;
    }
    count = available / elem_size;
  }

  // Compare by division: `offset + count * elem_size` can overflow
  // Py_ssize_t for a huge user-supplied count and slip past the check.
  if (count > available / elem_size) {
    PyErr_Format(PyExc_ValueError, "buffer is smaller than requested size");
    return NULL;
  }

  uint8_t* src = (uint8_t*) buffer.buf;
  THStorage* storage = THStorage_(newWithSize)(count);

#if defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR)
  memcpy(storage->data, src + offset, count);
#elif defined(TH_REAL_IS_SHORT)
  THP_decodeInt16Buffer(storage->data, src + offset, byte_order, count);
#elif defined(TH_REAL_IS_INT)
  THP_decodeInt32Buffer(storage->data, src + offset, byte_order, count);
#elif defined(TH_REAL_IS_LONG)
  THP_decodeInt64Buffer(storage->data, src + offset, byte_order, count);
#elif defined(TH_REAL_IS_FLOAT)
  THP_decodeFloatBuffer(storage->data, src + offset, byte_order, count);
#elif defined(TH_REAL_IS_DOUBLE)
  THP_decodeDoubleBuffer(storage->data, src + offset, byte_order, count);
#else
#error "Unknown type"
#endif

  return (PyObject*)THPStorage_(newObject)(storage);
  END_HANDLE_TH_ERRORS
}
#endif

PyObject * THPStorage_(writeFile)(THPStorage *self, PyObject *file)
{
HANDLE_TH_ERRORS
Expand Down Expand Up @@ -185,6 +269,7 @@ static PyMethodDef THPStorage_(methods)[] = {
{"_write_file", (PyCFunction)THPStorage_(writeFile), METH_O, NULL},
{"_new_with_file", (PyCFunction)THPStorage_(newWithFile), METH_O | METH_STATIC, NULL},
#ifndef THC_GENERIC_FILE
{"from_buffer", (PyCFunction)THPStorage_(fromBuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_share", (PyCFunction)THPStorage_(_share), METH_NOARGS, NULL},
{"_new_shared", (PyCFunction)THPStorage_(_newShared), METH_VARARGS | METH_STATIC, NULL},
#endif
Expand Down
0