import torch
import torch.nn as nn
import esm


# Define the sinusoidal positional encoding class
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # shape: (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        pe = self.pe[:x.size(0), :].to(x.device)  # make sure the positional encoding is on the same device as x
        x = x + pe
        return x


# Define a Transformer model consisting of a self-attention layer and an FFN layer
class ProteinTransformer(nn.Module):
    def __init__(self, d_model=1280, nhead=4, num_encoder_layers=1, dim_feedforward=5120):
        super(ProteinTransformer, self).__init__()
        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_encoder_layers)
        self.d_model = d_model

    def forward(self, src):
        output = self.transformer_encoder(src)
        return output


def load_esm1b_model():
    model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    return model, alphabet


def extract_embeddings(sequence, model, alphabet, device):
    batch_converter = alphabet.get_batch_converter()
    data = [("protein1", sequence)]
    _, _, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)  # make sure batch_tokens is on the same device as the model

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]

    # Drop the batch dimension and keep only the per-residue embeddings.
    # ESM-1b prepends a BOS token (and appends an EOS token), so the residue
    # embeddings live at indices 1..len(sequence).
    embedding = token_representations.squeeze(0)
    original_length = len(sequence)
    embedding = embedding[1:original_length + 1, :]

    return embedding


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    esm_model, alphabet = load_esm1b_model()
    esm_model = esm_model.to(device)  # move the model to the chosen device
    sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAKLFAIK"  # input protein sequence
    embeddings = extract_embeddings(sequence, esm_model, alphabet, device).to(device)  # extract per-residue embeddings from ESM-1b
    pos_encoder = PositionalEncoding(d_model=embeddings.size(1), max_len=embeddings.size(0)).to(device)  # initialize the positional encoding
    embeddings_with_pos = pos_encoder(embeddings.unsqueeze(1))  # add a batch dimension and apply the positional encoding
    transformer_model = ProteinTransformer().to(device)  # initialize the Transformer model
    with torch.no_grad():
        transformed_embeddings = transformer_model(embeddings_with_pos)
    print("Transformed Embedding Matrix Shape:", transformed_embeddings.shape)
    print("Transformed Embedding Matrix:", transformed_embeddings)


if __name__ == "__main__":
    main()

Transformed Embedding Matrix Shape: torch.Size([40, 1, 1280])
Transformed Embedding Matrix: tensor([[[-0.9183, 0.8383, -0.5008, ..., 0.8783, -0.6543, 0.4944]],

[[ 0.3480, -1.0895, 0.5983, ..., 0.7502, -0.7020, 0.1589]],

[[ 0.5885, -2.2407, 0.2032, ..., 0.7158, -0.7832, 0.8766]],

...,

[[-0.8607, 0.1170, -1.3206, ..., 0.8361, -0.3275, 0.9041]],

[[-0.1696, 0.6289, -0.8771, ..., 1.0834, -0.2475, 0.9045]],

[[ 0.6635, -0.5849, 0.3174, ..., 0.9572, -0.4215, 1.1555]]],
device='cuda:0')
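
The encoder output above is still per-residue: shape (seq_len, batch, d_model) = (40, 1, 1280). If a downstream task needs a single fixed-size vector per protein, one common follow-up is to mean-pool over the sequence dimension. This is not part of the script above; it is only a minimal sketch under that assumption, and the helper name is hypothetical:

import torch


def pool_protein_embedding(transformed_embeddings):
    # transformed_embeddings: (seq_len, batch, d_model), e.g. (40, 1, 1280).
    # Averaging over the sequence dimension yields one (batch, d_model)
    # protein-level vector (here: (1, 1280)).
    return transformed_embeddings.mean(dim=0)


# Example usage with the tensor printed above:
# protein_vector = pool_protein_embedding(transformed_embeddings)  # -> (1, 1280)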