
My_DETR

from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)  # inference-only walkthrough, no autograd needed
hidden_dim = 256

# 100 learned object queries for the decoder, plus learned 2D positional
# encodings: half the channels encode the row index, half the column
# (enough for feature maps up to 50x50).
query_pos = nn.Parameter(torch.rand(100, hidden_dim))
row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

print(col_embed.shape)
print(query_pos.shape)
torch.Size([50, 128])
torch.Size([100, 256])
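
The 20x20 grid used next is not arbitrary: in the full model it is the spatial size of the ResNet-50 feature map, whose total stride is 32. A minimal sketch of where H and W come from, assuming a 640x640 input purely for illustration:

backbone = resnet50()
features = nn.Sequential(*list(backbone.children())[:-2])  # drop avgpool and fc
x = torch.rand(1, 3, 640, 640)
print(features(x).shape)  # torch.Size([1, 2048, 20, 20]) -> H, W = 20, 20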
H, W = 20, 20
pos = torch.cat([
    col_embed[:W].unsqueeze(0).repeat(H, 1, 1),  # (H, W, 128): column code for every cell
    row_embed[:H].unsqueeze(1).repeat(1, W, 1),  # (H, W, 128): row code for every cell
], dim=-1).flatten(0, 1).unsqueeze(1)            # (H*W, 1, 256): one token per grid cell

print(pos.shape)
torch.Size([400, 1, 256])
print(col_embed.shape)
print(col_embed[:W].shape)
print(col_embed[:W].unsqueeze(0).shape)
print(col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape)
print(row_embed[:H].unsqueeze(1).repeat(1, W, 1).shape)
torch.Size([50, 128])
torch.Size([20, 128])
torch.Size([1, 20, 128])
torch.Size([20, 20, 128])
torch.Size([20, 20, 128])
print(torch.cat([
    col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
    row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).shape)

print(torch.cat([
    col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
    row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).shape)

print(torch.cat([
    col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
    row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1).shape)
torch.Size([20, 20, 256])
torch.Size([400, 256])
torch.Size([400, 1, 256])
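
The repeat/cat chain above can also be written with expand, which broadcasts a view instead of copying memory; a quick equivalence check on the same tensors (pos_alt is an illustrative name, not from the original):

pos_alt = torch.cat([
    col_embed[:W].unsqueeze(0).expand(H, -1, -1),  # broadcast instead of repeat
    row_embed[:H].unsqueeze(1).expand(-1, W, -1),
], dim=-1).flatten(0, 1).unsqueeze(1)
print(torch.equal(pos, pos_alt))  # True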
# Propagate through the transformer: h stands in for the projected
# backbone feature map (batch 1, 256 channels, 20x20 grid).
h = torch.rand(1, 256, 20, 20)
print(h.flatten(2).shape)
print(h.flatten(2).permute(2, 0, 1).shape)
print(query_pos.unsqueeze(1).shape)
# h = transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
#                 query_pos.unsqueeze(1)).transpose(0, 1)
torch.Size([1, 256, 400])
torch.Size([400, 1, 256])
torch.Size([100, 1, 256])
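
Before wiring up the real module, note the layout nn.Transformer expects by default (batch_first=False): the encoder input is (S, N, E) and the decoder input is (T, N, E). Here S = H*W = 400 feature tokens, T = 100 object queries, and N = 1. A quick check, with src and tgt as illustrative names:

src = pos + 0.1 * h.flatten(2).permute(2, 0, 1)  # (400, 1, 256) encoder input
tgt = query_pos.unsqueeze(1)                     # (100, 1, 256) decoder queries
assert src.shape == (400, 1, 256) and tgt.shape == (100, 1, 256)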
hidden_dim = 256
nheads = 8
num_encoder_layers = 6
num_decoder_layers = 6
transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
print(transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                  query_pos.unsqueeze(1)).transpose(0, 1).shape)
torch.Size([1, 100, 256])
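
Putting the pieces together: in the minimal DETR demo these tensors live inside a single module, with a 1x1 convolution projecting the 2048-channel backbone features down to hidden_dim and two linear heads producing class logits and normalized boxes. A sketch along those lines (MiniDETR is an illustrative name and num_classes=91 a COCO-style assumption; treat it as an outline, not the exact training-ready model):

class MiniDETR(nn.Module):
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        backbone = resnet50()
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])  # CNN features
        self.conv = nn.Conv2d(2048, hidden_dim, 1)   # project 2048 -> 256 channels
        self.transformer = nn.Transformer(hidden_dim, nheads,
                                          num_encoder_layers, num_decoder_layers)
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)  # +1 for "no object"
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, x):
        h = self.conv(self.backbone(x))              # (1, 256, H, W)
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)        # (H*W, 1, 256)
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)
        return self.linear_class(h), self.linear_bbox(h).sigmoid()

model = MiniDETR(num_classes=91)
logits, boxes = model(torch.rand(1, 3, 640, 640))
print(logits.shape, boxes.shape)  # torch.Size([1, 100, 92]) torch.Size([1, 100, 4])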