Font Transfer 2 Autoencoders
Font Transfer 2 Autoencoders
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.transform import rotate
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader
from torch import nn
import torch.optim as optim
from PIL import Image, ImageOps
True
2) Load Dataset
# load csv dataset
# Read the label table from a path relative to the current working directory.
data = pd.read_csv('Alphabets/dataset/datalabels.csv')
# Notebook-style preview of the first rows; the value is discarded when run as a script.
data.head()
# NOTE(review): train_font1 / train_font2 are not defined anywhere in the visible
# part of this export — presumably built from `data` in a cell that was dropped; confirm.
train_x = train_font1
train_y = train_font2
# Notebook-style preview of the two set sizes.
len(train_x), len(train_y)
(89, 89)
3) Data Augmentation
# gaussian noise
def add_noise(inputs,noise_factor=0.3):
    """Return `inputs` with additive Gaussian noise, clamped to [0, 1].

    Args:
        inputs: Tensor to perturb; the noise has the same shape and dtype.
        noise_factor: Multiplier applied to the standard-normal noise.

    Returns:
        A new tensor; `inputs` itself is not modified.
    """
    perturbation = torch.randn_like(inputs) * noise_factor
    return torch.clip(inputs + perturbation, 0., 1.)
# rotation
def rotate(input, degree):
    """Rotate `input` by `degree` degrees via `torchvision.transforms.functional.rotate`.

    NOTE(review): this definition replaces the `rotate` imported from
    `skimage.transform` at the top of the file, and the parameter name
    `input` shadows the builtin inside this function.
    """
    return TF.rotate(input, degree)
# resize
def resize(size, input):
    """Resize `input` to `size` with a `torchvision.transforms.Resize` transform."""
    scaler = T.Resize(size)
    return scaler(input)
# gaussian blur
def blurr(ksize, sigma,input):
    """Blur `input` with a `torchvision.transforms.GaussianBlur(ksize, sigma)` transform."""
    smoother = T.GaussianBlur(ksize, sigma)
    return smoother(input)
# center crop
def crop(size,input):
    """Center-crop `input` to `size`, then scale the crop back up to 28x28.

    The final resize goes through this file's `resize` helper so the
    result has a fixed (28, 28) size regardless of the crop size.
    """
    center_region = T.CenterCrop(size)(input)
    return resize(size=(28,28), input=center_region)
# random blocks
def add_random_boxes(img,n_k,size=2):
    """Black out `n_k` randomly placed `size` x `size` squares in an image.

    Args:
        img: Image convertible to an array by NumPy. It is handed to
            `Image.fromarray(..., 'RGB')` at the end, so it is assumed to be
            height x width x 3 — TODO confirm against callers.
        n_k: Number of squares to erase.
        size: Side length of each square, in pixels.

    Returns:
        A new RGB `PIL.Image` with the squares set to zero. The input is
        left untouched.

    Raises:
        ValueError: If `size` is larger than the image height or width
            (raised by `np.random.randint` for an empty range).
    """
    # np.array always copies, so the buffer is writable and the caller's
    # image is never modified in place (np.asarray may hand back the input
    # object itself or a read-only view).
    pixels = np.array(img)
    height, width = pixels.shape[:2]
    for _ in range(n_k):
        # Bound each axis by its own extent; the +1 lets a box sit flush
        # against the bottom/right edge (randint's upper bound is exclusive).
        y = np.random.randint(0, height - size + 1)
        x = np.random.randint(0, width - size + 1)
        pixels[y:y + size, x:x + size] = 0
    return Image.fromarray(pixels.astype('uint8'), 'RGB')
#%%
# Build the augmented source (train_x) and target (train_y) lists.
# Each of the first 26 samples is stored unrotated and then rotated by
# 90, 180 and 270 degrees, in the same order for both lists, so entry k of
# final_train_data corresponds to entry k of final_target_data.
final_train_data = []
final_target_data = []
size=(24,24)
rotation_angles = (90, 180, 270)
for i in range(0,26):
    final_train_data.append(train_x[i])
    final_train_data.extend(rotate(train_x[i], angle) for angle in rotation_angles)
    # Optional extra augmentations, currently disabled. Enabling one here
    # without adding the matching line to the target loop below breaks the
    # index pairing between the two lists.
    #final_train_data.append(add_noise(train_x[i],0.3)) #Gaussian noise
    ##final_train_data.append(resize(size,train_x[i])) #Resize
    #final_train_data.append(blurr(ksize=3, sigma=0.45,input=train_x[i])) #Blur
    #final_train_data.append(crop(size=24,input=train_x[i])) #Crop
for i in range(0,26):
    final_target_data.append(train_y[i])
    final_target_data.extend(rotate(train_y[i], angle) for angle in rotation_angles)
#%%
# Data augmentation
transform_train= transforms.Compose([
train_dataset = []
target_dataset = []
# Run every source sample, and the target sample at the same index, through
# the same transform pipeline, source first.
for i, source_sample in enumerate(final_train_data):
    transformed_dataset = transform_train(source_sample)
    transformed_targetset = transform_train(final_target_data[i])
    train_dataset.append(transformed_dataset)
    target_dataset.append(transformed_targetset)
# Both loaders use shuffle=True and shuffle independently, so batch k of one
# loader is not index-aligned with batch k of the other.
loader_options = dict(batch_size=16, shuffle=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, **loader_options)
target_loader = torch.utils.data.DataLoader(dataset=target_dataset, **loader_options)
# Decoder half of the autoencoder.
# NOTE(review): in this export the class header is followed directly by layer
# assignments. The `__init__` signature (called elsewhere with
# `encoded_space_dim` and `fc2_input_dim`), the `decoder_lin` block that the
# printed model summary lists, and `forward` are not visible here — restore
# them from the original notebook before running.
class Decoder(nn.Module):
# Reshape each flat vector (dim 1) into a 64-channel 3x3 feature map.
self.unflatten = nn.Unflatten(dim=1,
unflattened_size=(64, 3, 3))
# Three stride-2 transposed convolutions with channels 64 -> 7 -> 14 -> 1;
# batch norm + ReLU follow the first two, the last layer has no activation here.
# With these kernel/stride/padding settings the spatial size goes 3 -> 7 -> 14 -> 28.
self.decoder_conv = nn.Sequential(
nn.ConvTranspose2d(64, 7, 3,
stride=2, output_padding=0),
nn.BatchNorm2d(7),
nn.ReLU(True),
nn.ConvTranspose2d(7, 14, 3, stride=2,
padding=1, output_padding=1),
nn.BatchNorm2d(14),
nn.ReLU(True),
nn.ConvTranspose2d(14, 1, 3, stride=2,
padding=1, output_padding=1)
)
### Optimizer inputs: learning rate, the two networks, and their parameter groups
# NOTE(review): `d` and the optimizer construction itself are not visible in
# this export — confirm they are defined in a neighbouring cell.
lr = 0.001
encoder = Encoder(encoded_space_dim=d, fc2_input_dim=128)
decoder = Decoder(encoded_space_dim=d, fc2_input_dim=128)
# One parameter group per network, encoder first, so a single optimizer can
# update both.
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder)]
6) Model training
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')
# Move both the encoder and the decoder to the selected device
encoder.to(device)
decoder.to(device)
#%%
Decoder(
(decoder_lin): Sequential(
(0): Linear(in_features=128, out_features=128, bias=True)
(1): ReLU(inplace=True)
(2): Linear(in_features=128, out_features=576, bias=True)
(3): ReLU(inplace=True)
)
(unflatten): Unflatten(dim=1, unflattened_size=(64, 3, 3))
(decoder_conv): Sequential(
(0): ConvTranspose2d(64, 7, kernel_size=(3, 3), stride=(2, 2))
(1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(7, 14, kernel_size=(3, 3), stride=(2, 2),
padding=(1, 1), output_padding=(1, 1))
(4): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(14, 1, kernel_size=(3, 3), stride=(2, 2),
padding=(1, 1), output_padding=(1, 1))
)
)
### Training function
def train_epoch(encoder, decoder, device, dataloader, loss_fn,
                optimizer):
    """Run one reconstruction-training pass over `dataloader`.

    Each batch is encoded, decoded and compared against itself with
    `loss_fn` (unsupervised: the input is its own target), followed by one
    optimizer step.

    Args:
        encoder: Module mapping an image batch to its encoding.
        decoder: Module mapping an encoding back to an image batch.
        device: Device every batch is moved to before the forward pass.
        dataloader: Iterable yielding image-batch tensors (no labels).
        loss_fn: Callable `(reconstruction, original) -> scalar loss tensor`.
        optimizer: Optimizer holding the encoder and decoder parameters.

    Returns:
        Mean of the per-batch losses, as returned by `np.mean`.
    """
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    train_loss = []
    for image_batch in dataloader:
        # Move tensor to the proper device
        image_batch = image_batch.to(device)
        # Encode, then decode
        decoded_data = decoder(encoder(image_batch))
        # Reconstruction loss against the input itself
        loss = loss_fn(decoded_data, image_batch)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Keep a detached CPU copy so the autograd graph is not retained
        train_loss.append(loss.detach().cpu().numpy())
    return np.mean(train_loss)
#%%
### Testing function
def _plot_reconstructions(encoder, decoder, images, source_title,
                          output_title, n=10):
    """Show the first `n` images (top row) above their decoder output (bottom row).

    Each sample is drawn exactly once into its own axes of a 2 x n grid.
    The earlier code redrew the same sample `n` times onto the same two
    axes, which piled up image artists and made rendering extremely slow.

    Args:
        encoder: Module applied to each image after a batch dimension is added.
        decoder: Module applied to the encoder output.
        images: Indexable collection of tensors; each gets `.unsqueeze(0)`
            before encoding (assumed single-image tensors — TODO confirm).
        source_title: Title placed over the middle column of the top row.
        output_title: Title placed over the middle column of the bottom row.
        n: Maximum number of samples to display.
    """
    count = min(n, len(images))
    if count <= 0:
        return
    # Inference only: switch both networks to eval mode once, up front.
    encoder.eval()
    decoder.eval()
    plt.figure(figsize=(16, 4))
    for col in range(count):
        # `device` is the module-level device the networks were moved to.
        img = images[col].unsqueeze(0).to(device)
        with torch.no_grad():
            rec_img = decoder(encoder(img))
        panels = ((img, source_title), (rec_img, output_title))
        for row, (picture, title) in enumerate(panels):
            ax = plt.subplot(2, count, row * count + col + 1)
            plt.imshow(picture.cpu().squeeze().numpy(), cmap='gist_gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if col == count // 2:
                ax.set_title(title)
    plt.show()


def plot_targets(encoder,decoder,n=10):
    """Plot the first `n` samples of `final_target_data` and their reconstructions."""
    _plot_reconstructions(encoder, decoder, final_target_data,
                          'Font 2', 'Reconstructed', n)


def plot_train(encoder,decoder,n=10):
    """Plot the first `n` samples of `final_train_data` and their reconstructions."""
    _plot_reconstructions(encoder, decoder, final_train_data,
                          'Font 1', 'Reconstructed', n)


def convert(encoder,decoder,n=10):
    """Plot the first `n` samples of `final_train_data` next to `decoder(encoder(x))`.

    The output row is titled 'Font 2'. Presumably the caller passes the
    font-1 encoder with the font-2 decoder — verify at the call site.
    """
    _plot_reconstructions(encoder, decoder, final_train_data,
                          'Font 1', 'Font 2', n)
#%% Reconstruction of Font 1
num_epochs =1000
diz_loss = {'train_loss':[],'val_loss':[]}
# NOTE(review): `optim` must be an optimizer instance here. The file also
# imports the `torch.optim` module under that same name, so confirm it is
# rebound to an optimizer before this cell runs.
for epoch in range(num_epochs):
    train_loss = train_epoch(encoder, decoder, device,
                             train_loader, loss_fn, optim)
    # Evaluated on the training loader as well; no separate validation split is used here.
    val_loss = test_epoch(encoder, decoder, device, train_loader, loss_fn)
    if epoch % 100 == 0:
        print('\n EPOCH {}/{} \t train loss {} '.format(epoch + 1,
              num_epochs, train_loss))
        print('\n EPOCH {}/{} \t val loss {}'.format(epoch + 1,
              num_epochs, val_loss))
    diz_loss['train_loss'].append(train_loss)
    # Record the second curve too; it was computed but never stored, so the
    # 'Convert Loss' line below was always empty. float() assumes test_epoch
    # returns a scalar (number or 0-d tensor) — TODO confirm.
    diz_loss['val_loss'].append(float(val_loss))
plt.figure(figsize=(10,8))
plt.semilogy(diz_loss['train_loss'], label='Train Loss')
plt.semilogy(diz_loss['val_loss'], label='Convert Loss')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
#plt.grid()
plt.legend()
# plt.ylim(10e-2, 0.2)
#plt.title('loss')
plt.show()
plot_train(encoder,decoder,n=10)
#plot_targets(encoder,decoder,n=10)
----------------------------------------------------------------------
-----
KeyboardInterrupt Traceback (most recent call
last)
~\AppData\Local\Temp\ipykernel_2016\2962927840.py in <module>
----> 1 plot_targets(encoder,decoder,n=10)
~\AppData\Local\Temp\ipykernel_2016\1435074076.py in
plot_targets(encoder, decoder, n)
77 if i == n//2:
78 ax.set_title('Reconstructed')
---> 79 plt.show()
80
81 #%% Plot reconstruction of Font 1
~\anaconda3\lib\site-packages\matplotlib\pyplot.py in show(*args,
**kwargs)
387 """
388 _warn_if_gui_out_of_main_thread()
--> 389 return _get_backend_mod().show(*args, **kwargs)
390
391
~\anaconda3\lib\site-packages\matplotlib_inline\backend_inline.py in
show(close, block)
88 try:
89 for figure_manager in Gcf.get_all_fig_managers():
---> 90 display(
91 figure_manager.canvas.figure,
92
metadata=_fetch_figure_metadata(figure_manager.canvas.figure)
~\anaconda3\lib\site-packages\IPython\core\display.py in
display(include, exclude, metadata, transient, display_id, *objs,
**kwargs)
318 publish_display_data(data=obj, metadata=metadata,
**kwargs)
319 else:
--> 320 format_dict, md_dict = format(obj,
include=include, exclude=exclude)
321 if not format_dict:
322 # nothing to display (e.g. _ipython_display_
took over)
~\anaconda3\lib\site-packages\IPython\core\formatters.py in
format(self, obj, include, exclude)
178 md = None
179 try:
--> 180 data = formatter(obj)
181 except:
182 # FIXME: log the exception
~\anaconda3\lib\site-packages\IPython\core\formatters.py in
catch_format_error(method, self, *args, **kwargs)
222 """show traceback on failed format call"""
223 try:
--> 224 r = method(self, *args, **kwargs)
225 except NotImplementedError:
226 # don't warn on NotImplementedErrors
~\anaconda3\lib\site-packages\IPython\core\formatters.py in
__call__(self, obj)
339 pass
340 else:
--> 341 return printer(obj)
342 # Finally look for special method names
343 method = get_real_method(obj, self.print_method)
~\anaconda3\lib\site-packages\IPython\core\pylabtools.py in
print_figure(fig, fmt, bbox_inches, base64, **kwargs)
149 FigureCanvasBase(fig)
150
--> 151 fig.canvas.print_figure(bytes_io, **kw)
152 data = bytes_io.getvalue()
153 if fmt == 'svg':
~\anaconda3\lib\site-packages\matplotlib\backend_bases.py in
print_figure(self, filename, dpi, facecolor, edgecolor, orientation,
format, bbox_inches, pad_inches, bbox_extra_artists, backend,
**kwargs)
2317 # force the figure dpi to 72), so we need to
set it again here.
2318 with cbook._setattr_cm(self.figure, dpi=dpi):
-> 2319 result = print_method(
2320 filename,
2321 facecolor=facecolor,
~\anaconda3\lib\site-packages\matplotlib\backend_bases.py in
wrapper(*args, **kwargs)
1646 kwargs.pop(arg)
1647
-> 1648 return func(*args, **kwargs)
1649
1650 return wrapper
~\anaconda3\lib\site-packages\matplotlib\_api\deprecation.py in
wrapper(*inner_args, **inner_kwargs)
413 else deprecation_addendum,
414 **kwargs)
--> 415 return func(*inner_args, **inner_kwargs)
416
417 DECORATORS[wrapper] = decorator
~\anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py in
print_png(self, filename_or_obj, metadata, pil_kwargs, *args)
538 *metadata*, including the default 'Software' key.
539 """
--> 540 FigureCanvasAgg.draw(self)
541 mpl.image.imsave(
542 filename_or_obj, self.buffer_rgba(), format="png",
origin="upper",
~\anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py in
draw(self)
434 (self.toolbar._wait_cursor_for_draw_cm() if
self.toolbar
435 else nullcontext()):
--> 436 self.figure.draw(self.renderer)
437 # A GUI class may be need to update a window using
this draw, so
438 # don't forget to call the superclass.
~\anaconda3\lib\site-packages\matplotlib\artist.py in
draw_wrapper(artist, renderer, *args, **kwargs)
71 @wraps(draw)
72 def draw_wrapper(artist, renderer, *args, **kwargs):
---> 73 result = draw(artist, renderer, *args, **kwargs)
74 if renderer._rasterizing:
75 renderer.stop_rasterizing()
~\anaconda3\lib\site-packages\matplotlib\artist.py in
draw_wrapper(artist, renderer)
48 renderer.start_filter()
49
---> 50 return draw(artist, renderer)
51 finally:
52 if artist.get_agg_filter() is not None:
~\anaconda3\lib\site-packages\matplotlib\figure.py in draw(self,
renderer)
2835
2836 self.patch.draw(renderer)
-> 2837 mimage._draw_list_compositing_images(
2838 renderer, self, artists,
self.suppressComposite)
2839
~\anaconda3\lib\site-packages\matplotlib\image.py in
_draw_list_compositing_images(renderer, parent, artists,
suppress_composite)
130 if not_composite or not has_images:
131 for a in artists:
--> 132 a.draw(renderer)
133 else:
134 # Composite any adjacent images together
~\anaconda3\lib\site-packages\matplotlib\artist.py in
draw_wrapper(artist, renderer)
48 renderer.start_filter()
49
---> 50 return draw(artist, renderer)
51 finally:
52 if artist.get_agg_filter() is not None:
~\anaconda3\lib\site-packages\matplotlib\axes\_base.py in draw(self,
renderer)
3089 renderer.stop_rasterizing()
3090
-> 3091 mimage._draw_list_compositing_images(
3092 renderer, self, artists,
self.figure.suppressComposite)
3093
~\anaconda3\lib\site-packages\matplotlib\image.py in
_draw_list_compositing_images(renderer, parent, artists,
suppress_composite)
130 if not_composite or not has_images:
131 for a in artists:
--> 132 a.draw(renderer)
133 else:
134 # Composite any adjacent images together
~\anaconda3\lib\site-packages\matplotlib\artist.py in
draw_wrapper(artist, renderer)
48 renderer.start_filter()
49
---> 50 return draw(artist, renderer)
51 finally:
52 if artist.get_agg_filter() is not None:
~\anaconda3\lib\site-packages\matplotlib\image.py in draw(self,
renderer, *args, **kwargs)
644 renderer.draw_image(gc, l, b, im, trans)
645 else:
--> 646 im, l, b, trans = self.make_image(
647 renderer, renderer.get_image_magnification())
648 if im is not None:
~\anaconda3\lib\site-packages\matplotlib\image.py in make_image(self,
renderer, magnification, unsampled)
954 clip = ((self.get_clip_box() or self.axes.bbox) if
self.get_clip_on()
955 else self.figure.bbox)
--> 956 return self._make_image(self._A, bbox,
transformed_bbox, clip,
957 magnification,
unsampled=unsampled)
958
~\anaconda3\lib\site-packages\matplotlib\image.py in _make_image(self,
A, in_bbox, out_bbox, clip_bbox, magnification, unsampled,
round_to_pixel_border)
346 "this method is called.")
347
--> 348 clipped_bbox = Bbox.intersection(out_bbox, clip_bbox)
349
350 if clipped_bbox is None:
~\anaconda3\lib\site-packages\matplotlib\transforms.py in
intersection(bbox1, bbox2)
677 None if they don't.
678 """
--> 679 x0 = np.maximum(bbox1.xmin, bbox2.xmin)
680 x1 = np.minimum(bbox1.xmax, bbox2.xmax)
681 y0 = np.maximum(bbox1.ymin, bbox2.ymin)
~\anaconda3\lib\site-packages\matplotlib\transforms.py in xmin(self)
327 def xmin(self):
328 """The left edge of the bounding box."""
--> 329 return np.min(self.get_points()[:, 0])
330
331 @property
~\anaconda3\lib\site-packages\matplotlib\transforms.py in
get_points(self)
1127 # from the result, taking care to make the
orientation the
1128 # same.
-> 1129 points = self._transform.transform(
1130 [[p[0, 0], p[0, 1]],
1131 [p[1, 0], p[0, 1]],
~\anaconda3\lib\site-packages\matplotlib\transforms.py in
transform(self, values)
1501
1502 # Transform the values
-> 1503 res =
self.transform_affine(self.transform_non_affine(values))
1504
1505 # Convert the result back to the shape of the input
values.
~\anaconda3\lib\site-packages\matplotlib\transforms.py in
transform_affine(self, points)
2417 def transform_affine(self, points):
2418 # docstring inherited
-> 2419 return self.get_affine().transform(points)
2420
2421 def transform_non_affine(self, points):
~\anaconda3\lib\site-packages\matplotlib\transforms.py in
get_affine(self)
2444 return self._b.get_affine()
2445 else:
-> 2446 return
Affine2D(np.dot(self._b.get_affine().get_matrix(),
2447
self._a.get_affine().get_matrix()))
2448