question about training speed · Issue #3 · roatienza/straug · GitHub
question about training speed #3
Open
@littletomatodonkey

Description

Thanks for your excellent work! Training seems to be very slow when I use straug (about 6x slower than without it). What speed do you see in your own tests? My augmentation code is below.

import random

import numpy as np
from PIL import Image


class RecStraugRandAug(object):
    def __init__(self, num_aug=2, prob=0.5, **kwargs):
        super().__init__()
        self.num_aug = num_aug  # number of augmentations applied per image
        self.prob = prob  # passed to each op as its apply probability
        try:
            from straug.blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur
            from straug.camera import Contrast, Brightness, JpegCompression, Pixelate
            from straug.geometry import Perspective, Shrink, Rotate
            from straug.noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise
            from straug.pattern import Grid, VGrid, HGrid, RectGrid, EllipseGrid
            from straug.process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color
            from straug.warp import Stretch, Distort, Curve
            from straug.weather import Fog, Snow, Frost, Rain, Shadow
            # Ops are instantiated once and grouped by category.
            self.augs = [
                [GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur()],
                [Contrast(), Brightness(), JpegCompression(), Pixelate()],
                [Perspective(), Shrink(), Rotate()],
                [GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()],
                [Grid(), VGrid(), HGrid(), RectGrid(), EllipseGrid()],
                [Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()],
                [Stretch(), Distort(), Curve()],
                [Fog(), Snow(), Frost(), Rain(), Shadow()],
            ]
        except Exception as ex:
            print(f"exception: {ex}, you can install straug using `pip install straug`")
            exit(-1)

    def __call__(self, data):
        img = Image.fromarray(data["image"])
        for _ in range(self.num_aug):
            # Pick a random category, then a random op within it.
            aug_type_idx = np.random.randint(0, len(self.augs))
            aug_idx = np.random.randint(0, len(self.augs[aug_type_idx]))
            # random.randint is inclusive on both ends, so mag is drawn from {-1, 0, 1, 2}.
            img = self.augs[aug_type_idx][aug_idx](img, mag=random.randint(-1, 2), prob=self.prob)
        data["image"] = np.array(img)
        return data
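
To narrow down where the slowdown comes from, it may help to time each op on its own. Below is a rough sketch, not something from the straug repo: `time_ops` is a hypothetical helper that assumes only that each op is a callable accepting `(img, mag=..., prob=...)`, as in the code above. It passes `prob=1.0` so that every call actually runs the op.

```python
import time


def time_ops(ops, img, repeats=5, mag=0):
    """Return {name: mean seconds per call}, slowest op first.

    `ops` maps a display name to a callable with the signature
    op(img, mag=..., prob=...). This helper is hypothetical and only
    relies on that calling convention.
    """
    timings = {}
    for name, op in ops.items():
        start = time.perf_counter()
        for _ in range(repeats):
            # prob=1.0 so the op is always applied while timing it
            op(img, mag=mag, prob=1.0)
        timings[name] = (time.perf_counter() - start) / repeats
    # Sort by mean time, descending, so the most expensive ops come first
    return dict(sorted(timings.items(), key=lambda kv: kv[1], reverse=True))
```

With the class above, the input dict could be built as `{type(op).__name__: op for group in aug.augs for op in group}`, where `aug` is a `RecStraugRandAug` instance and `img` is a PIL image of a typical training sample. If a handful of ops dominate the timings, dropping them from `self.augs` would be one way to check whether they account for most of the 6x slowdown.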
