[WIP] Add backend dual loss and plan computation for stochastic optimization or regularized OT by rflamary · Pull Request #360 · PythonOT/POT · GitHub

[WIP] Add backend dual loss and plan computation for stochastic optimization or regularized OT #360

Merged: 7 commits, Apr 4, 2022
2 changes: 2 additions & 0 deletions RELEASES.md
@@ -5,6 +5,8 @@

#### New features

- Add stochastic loss and OT plan computation for regularized OT and backend examples (PR #360).
- Implementation of factored OT with emd and sinkhorn (PR #358).
- A brand new logo for POT (PR #357).
- Better list of related examples in quick start guide with `minigallery` (PR #334).
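The new `ot.stochastic` entry points are exercised in the two examples below. A minimal sketch of the intended usage, with illustrative shapes and values (the examples give the full training loops):

```python
import torch
import ot

xs = torch.randn(50, 2)        # source samples
xt = torch.randn(50, 2) + 2    # target samples

# one dual variable per sample, optimized with any torch optimizer
u = torch.zeros(50, requires_grad=True)
v = torch.zeros(50, requires_grad=True)

# dual objective of entropic regularized OT (to be maximized)
loss = ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=1.0)

# primal OT plan recovered from the dual potentials
G = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=1.0)
```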
168 changes: 168 additions & 0 deletions examples/backends/plot_dual_ot_pytorch.py
@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
r"""
======================================================================
Dual OT solvers for entropic and quadratic regularized OT with Pytorch
======================================================================


"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import numpy as np
import matplotlib.pyplot as pl
import torch
import ot
import ot.plot

# %%
# Data generation
# ---------------

torch.manual_seed(1)

n_source_samples = 100
n_target_samples = 100
theta = 2 * np.pi / 20
noise_level = 0.1

Xs, ys = ot.datasets.make_data_classif(
'gaussrot', n_source_samples, nz=noise_level)
Xt, yt = ot.datasets.make_data_classif(
'gaussrot', n_target_samples, theta=theta, nz=noise_level)

# one of the target modes changes its variance (no linear mapping)
Xt[yt == 2] *= 3
Xt = Xt + 4


# %%
# Plot data
# ---------

pl.figure(1, (10, 5))
pl.clf()
pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples')
pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')

# %%
# Convert data to torch tensors
# -----------------------------

xs = torch.tensor(Xs)
xt = torch.tensor(Xt)

# %%
# Estimating dual variables for entropic OT
# -----------------------------------------
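#
# We keep one dual variable per sample and maximize, roughly, the entropic
# dual objective E_a[u] + E_b[v] - reg * E_{a x b}[exp((u + v - C) / reg)]
# by minimizing its negation with Adam (see the ot.stochastic documentation
# for the exact expression implemented by loss_dual_entropic).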

u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)

reg = 0.5

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200


losses = []

for i in range(n_iter):

    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


pl.figure(2)
pl.plot(losses)
pl.grid()
pl.title('Dual objective (negative)')
pl.xlabel("Iterations")

Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)
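# Ge is the primal plan recovered from the learned potentials: up to the
# marginal weights, Ge[i, j] is proportional to
# exp((u[i] + v[j] - C[i, j]) / reg)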

# %%
# Plot the estimated entropic OT plan
# -----------------------------------

pl.figure(3, (10, 5))
pl.clf()
ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
pl.legend(loc=0)
pl.title('OT plan with entropic regularization')


# %%
# Estimating dual variables for quadratic OT
# ------------------------------------------
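#
# The same stochastic approach applies to quadratic (L2) regularization.
# Here the recovered plan is proportional to the positive part
# (u[i] + v[j] - C[i, j])_+, so entries where the dual constraint is slack
# are exactly zero, which typically gives a sparse plan.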

u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)

reg = 0.01

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200


losses = []


for i in range(n_iter):

    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


pl.figure(4)
pl.plot(losses)
pl.grid()
pl.title('Dual objective (negative)')
pl.xlabel("Iterations")

Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)


# %%
# Plot the estimated quadratic OT plan
# ------------------------------------

pl.figure(5, (10, 5))
pl.clf()
ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)
pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
pl.legend(loc=0)
pl.title('OT plan with quadratic regularization')
189 changes: 189 additions & 0 deletions examples/backends/plot_stoch_continuous_ot_pytorch.py
@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
r"""
======================================================================
Continuous OT plan estimation with Pytorch
======================================================================


"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import numpy as np
import matplotlib.pyplot as pl
import torch
from torch import nn
import ot
import ot.plot

# %%
# Data generation
# ---------------

torch.manual_seed(42)
np.random.seed(42)

n_source_samples = 10000
n_target_samples = 10000

Xs = np.random.randn(n_source_samples, 2) * 0.5
Xt = np.random.randn(n_target_samples, 2) * 2

# the target is a shifted Gaussian with a larger variance (no linear mapping)
Xt = Xt + 4


# %%
# Plot data
# ---------
nvisu = 300
pl.figure(1, (5, 5))
pl.clf()
pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5)
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5)
pl.legend(loc=0)
ax_bounds = pl.axis()
pl.title('Source and target distributions')

# %%
# Convert data to torch tensors
# -----------------------------

xs = torch.tensor(Xs)
xt = torch.tensor(Xt)

# %%
# Estimating deep dual variables for entropic OT
# ----------------------------------------------
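#
# Instead of one dual variable per sample, we now parameterize the dual
# potentials u and v as small neural networks. The potentials then
# generalize to unseen points, which makes minibatch optimization possible
# and lets us evaluate the plan between the underlying continuous
# distributions.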

torch.manual_seed(42)

# define the MLP model


class Potential(torch.nn.Module):
    # a small MLP with one hidden ReLU layer, mapping a 2D point to a scalar
    def __init__(self):
        super(Potential, self).__init__()
        self.fc1 = nn.Linear(2, 200)
        self.fc2 = nn.Linear(200, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        output = self.fc1(x)
        output = self.relu(output)
        output = self.fc2(output)
        return output.ravel()


u = Potential().double()
v = Potential().double()

reg = 1

optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)

# number of iterations and minibatch size
n_iter = 1000
n_batch = 500


losses = []

for i in range(n_iter):

    # sample a minibatch from each distribution
    iperms = torch.randint(0, n_source_samples, (n_batch,))
    ipermt = torch.randint(0, n_target_samples, (n_batch,))

    xsi = xs[iperms]
    xti = xt[ipermt]

    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


pl.figure(2)
pl.plot(losses)
pl.grid()
pl.title('Dual objective (negative)')
pl.xlabel("Iterations")


# %%
# Plot the density on target for a given source sample
# ----------------------------------------------------
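#
# Evaluating plan_dual_entropic with a single source sample and a grid of
# target positions (weighted by wt) gives, in effect, one row of the
# continuous plan: the density of the mass transported from that sample.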


nv = 100
xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)

XX, YY = np.meshgrid(xl, yl)

xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)

wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2))
wxg = wxg / np.sum(wxg)
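# wxg is a Gaussian weighting of the grid points, used below as the target
# weights wt when evaluating the plan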

xg = torch.tensor(xg)
wxg = torch.tensor(wxg)


pl.figure(4, (12, 4))
pl.clf()
pl.subplot(1, 3, 1)

iv = 2
Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
Gg = Gg.reshape((nv, nv)).detach().numpy()

pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
pl.legend(loc=0)
ax_bounds = pl.axis()
pl.title('Density of transported source sample')

pl.subplot(1, 3, 2)

iv = 3
Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
Gg = Gg.reshape((nv, nv)).detach().numpy()

pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
pl.legend(loc=0)
ax_bounds = pl.axis()
pl.title('Density of transported source sample')

pl.subplot(1, 3, 3)

iv = 6
Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
Gg = Gg.reshape((nv, nv)).detach().numpy()

pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported source sample')
pl.legend(loc=0)
ax_bounds = pl.axis()
pl.title('Density of transported source sample')