Original project: https://github.com/IdoSpringer/ERGO-II

Model Overview

  • Input: a TCR (alpha-chain sequence, beta-chain sequence, V gene, J gene) and a peptide (amino-acid sequence plus MHC (HLA) information)
  • Processing: an autoencoder (AE) or LSTM encoder produces embeddings, which are concatenated and fed to a feed-forward network (see the sketch after this list)
  • Output: a score, interpreted as whether the TCR and peptide bind
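A minimal sketch of this flow (hypothetical module and dimension names; the actual model lives in the repo's Trainer.py as ERGOLightning):

import torch
import torch.nn as nn

class ScoreHead(nn.Module):
    """Toy stand-in for the ERGO-II head: concatenate the TCR and
    peptide embeddings and map them to a single binding score."""
    def __init__(self, tcr_dim=100, pep_dim=100, hidden=300):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(tcr_dim + pep_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, tcr_emb, pep_emb):
        x = torch.cat([tcr_emb, pep_emb], dim=-1)  # concatenated embeddings
        return torch.sigmoid(self.mlp(x)).squeeze(-1)  # score in (0, 1)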

Reproduction and Testing

Evaluation.py

The original paper only provides two checkpoint versions, 1veajht and 1meajht, so the vdjdb and mcpas datasets each map to exactly one set of pretrained weights.
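For reference, the scripts below expect the weights laid out as follows (directory names taken from get_model in Predict.py; the checkpoint filename inside checkpoints/ may vary):

Models/
├── version_1veajht/   # vdjdb weights
│   ├── checkpoints/
│   └── meta_tags.csv
└── version_1meajht/   # mcpas weights
    ├── checkpoints/
    └── meta_tags.csv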

if __name__ == '__main__':
    version = '1veajht'
    key = 'vdjdb'
    if key == 'mcpas':
        freq_peps = Sampler.frequent_peptides('data/McPAS-TCR.csv', 'mcpas', 20)
        spb_table = pd.DataFrame()
        # for version in ['1me', '1ml', '1mea', '1mla', '1meaj', '1mlaj',
        #                 '1meajh', '1mlajh', '1meajht', '1mlajht']:
        print(version)
        spb_results = spb_main(version, data_key='mcpas')
        print(spb_results)
        spb_table[version] = spb_results
        spb_table.index = freq_peps
        spb_table.to_csv('plots/AE_mcpas_spb_results.csv')
    elif key == 'vdjdb':
        freq_peps = ['IPSINVHHY', 'TPRVTGGGAM', 'NLVPMVATV', 'GLCTLVAML',
                     'RAKFKQLL', 'YVLDHLIVV', 'GILGFVFTL', 'PKYVKQNTLKLAT',
                     'CINGVCWTV', 'KLVALGINAV', 'ATDALMTGY', 'RPRGEVRFL',
                     'LLWNGPMAV', 'GTSGSPIVNR', 'GTSGSPIINR', 'KAFSPEVIPMF',
                     'TPQDLNTML', 'EIYKRWII', 'KRWIILGLNK', 'FRDYVDRFYKTLRAEQASQE',
                     'GPGHKARVL', 'FLKEKGGL']
        spb_table = pd.DataFrame()
        # for version in ['1fe', '1fl', '1fea', '1fla', '1feaj', '1flaj',
        #                 '1feajh', '1flajh', '1feajht', '1flajht']:
        print(version)
        spb_results = spb_main(version, data_key='vdjdb')
        print(spb_results)
        spb_table[version] = spb_results
        spb_table.index = freq_peps
        spb_table.to_csv('plots/AE_vdjdb_spb_results.csv')
    pass

Testing on Our Own Dataset

Since our dataset contains only the TCR beta-chain sequence, the peptide sequence, and the HLA allele, the script is modified to fill the alpha chain, V gene, and J gene with UNK.
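For illustration, the modified script expects a CSV with at least CDR3, Antigen, and HLA columns (the rows below are made up):

CDR3,Antigen,HLA
CASSIRSSYEQYF,NLVPMVATV,A*02:01
CASSLAPGATNEKLFF,GILGFVFTL,A*02:01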

Predict.py

import numpy as np
import pickle
from Loader import SignedPairsDataset, get_index_dicts
from Trainer import ERGOLightning
from torch.utils.data import DataLoader
from argparse import Namespace
import torch
import pandas as pd
import os
from os import listdir
from os.path import isfile, join
import sys

def read_input_file(datafile):
    amino_acids = [letter for letter in 'ARNDCEQGHILKMFPSTWYV']
    all_pairs = []
    def invalid(seq):
        return pd.isna(seq) or any([aa not in amino_acids for aa in seq])
    data = pd.read_csv(datafile)
    for index in range(len(data)):
        sample = {}
        sample['tcrb'] = data['CDR3'][index]
        sample['peptide'] = data['Antigen'][index]

        # Fill in the MHC column and normalize its format
        # (data.get('HLA', [None])[index] breaks for index > 0 when the
        # column is missing, so check for the column explicitly)
        hla_value = data['HLA'][index] if 'HLA' in data.columns else None
        if pd.isna(hla_value):
            sample['mhc'] = 'UNK'
        else:
            sample['mhc'] = f"HLA-{hla_value.replace('*', '')}"  # e.g. 'A*02:01' -> 'HLA-A02:01'

        # Fill in the TCR alpha chain (all UNK)
        sample['tcra'] = 'UNK'

        # Fill the remaining columns with default values
        sample['va'] = 'UNK'
        sample['ja'] = 'UNK'
        sample['vb'] = 'UNK'
        sample['jb'] = 'UNK'
        sample['t_cell_type'] = 'UNK'
        sample['protein'] = 'UNK'
        sample['sign'] = 0  # default sign is 0
        if invalid(sample['tcrb']) or invalid(sample['peptide']):
            continue
        all_pairs.append(sample)
    return all_pairs, data


def load_model(hparams, checkpoint_path):
    model = ERGOLightning(hparams)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model


def get_model(dataset):
    if dataset == 'vdjdb':
        version = '1veajht'
    elif dataset == 'mcpas':
        version = '1meajht'
    # get model file from version
    checkpoint_path = os.path.join('Models', 'version_' + version, 'checkpoints')
    files = [f for f in listdir(checkpoint_path) if isfile(join(checkpoint_path, f))]
    checkpoint_path = os.path.join(checkpoint_path, files[0])
    # get args from version
    args_path = os.path.join('Models', 'version_' + version, 'meta_tags.csv')
    with open(args_path, 'r') as file:
        lines = file.readlines()
        args = {}
        for line in lines[1:]:
            key, value = line.strip().split(',')
            if key in ['dataset', 'tcr_encoding_model', 'cat_encoding']:
                args[key] = value
            else:
                args[key] = eval(value)
    hparams = Namespace(**args)
    checkpoint = checkpoint_path
    model = load_model(hparams, checkpoint)
    train_pickle = 'Samples/' + model.dataset + '_train_samples.pickle'
    test_pickle = 'Samples/' + model.dataset + '_test_samples.pickle'
    return model, train_pickle, test_pickle


def get_train_dicts(train_pickle):
    with open(train_pickle, 'rb') as handle:
        train = pickle.load(handle)
    train_dicts = get_index_dicts(train)
    return train_dicts


def predict(dataset, test_file):
    model, train_file, test_pickle = get_model(dataset)
    train_dicts = get_train_dicts(train_file)
    test_samples, dataframe = read_input_file(test_file)
    test_dataset = SignedPairsDataset(test_samples, train_dicts)
    batch_size = 1000
    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                        collate_fn=lambda b: test_dataset.collate(b, tcr_encoding=model.tcr_encoding_model,
                                                                  cat_encoding=model.cat_encoding))
    outputs = []
    for batch_idx, batch in enumerate(loader):
        output = model.validation_step(batch, batch_idx)
        if output:
            outputs.extend(output['y_hat'].tolist())

    # Verify and log length mismatch
    if len(outputs) != len(dataframe):
        print(f"Length mismatch detected!")
        print(f"Outputs length: {len(outputs)}")
        print(f"Dataframe length: {len(dataframe)}")
        print(f"Sample outputs: {outputs[:5]}")
        raise ValueError(f"Length mismatch: outputs ({len(outputs)}) vs dataframe ({len(dataframe)})")

    dataframe['Score'] = outputs
    return dataframe

if __name__ == '__main__':
    df = predict('mcpas', 'test_neg_part1.csv')
    print(df)
    df.to_csv('results/2.csv', index=False)
    pass


# NOTE: fix sklearn import problem with this in terminal:
# export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/dsi/speingi/anaconda3/lib/
# or just conda install libgcc
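A side note on get_model above: parsing the meta_tags.csv values with eval works, but ast.literal_eval is a safer stand-in for the same job (a sketch, not part of the original script):

import ast

def parse_meta_value(value):
    """Hypothetical replacement for the eval() call in get_model:
    literal-evaluates ints/floats/bools/None, leaves plain strings as-is."""
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        return value

With this, the special-casing of the string-valued keys ('dataset', 'tcr_encoding_model', 'cat_encoding') becomes unnecessary, since bare strings fail literal evaluation and fall through unchanged.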

Filtering Issue Discovered

The script apparently filters out samples with invalid sequences, so the number of input rows and output rows no longer line up. The Predict script was therefore modified to log the filtering and to output only the rows that survive it.

(The imports and the read_input_file / load_model / get_model / get_train_dicts helpers are unchanged from the listing above; only predict and the entry point differ.)

def predict(dataset, test_file):
    model, train_file, test_pickle = get_model(dataset)
    train_dicts = get_train_dicts(train_file)
    test_samples, dataframe = read_input_file(test_file)
    print(f"Original dataframe rows: {len(dataframe)}")  # row count before filtering
    print("Missing values per column:")
    print(dataframe.isnull().sum())  # check for missing values

    test_dataset = SignedPairsDataset(test_samples, train_dicts)
    batch_size = 1000
    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                        collate_fn=lambda b: test_dataset.collate(b, tcr_encoding=model.tcr_encoding_model,
                                                                  cat_encoding=model.cat_encoding))
    outputs = []

    total_batches = 0  # total number of batches
    total_samples = 0  # total number of samples
    for batch_idx, batch in enumerate(loader):
        total_batches += 1
        if not batch:  # skip empty batches
            print(f"Skipping empty batch: {batch_idx}")
            continue
        output = model.validation_step(batch, batch_idx)
        if output:
            outputs.extend(output['y_hat'].tolist())
            total_samples += len(output['y_hat'])
            print(f"Batch {batch_idx}: {len(output['y_hat'])} samples")  # samples per batch

    print(f"Total samples processed: {total_samples}")
    print(f"Total model outputs: {len(outputs)}")

    if len(outputs) != len(dataframe):
        # Rows were dropped by read_input_file; re-apply the same validity
        # check so the remaining rows align with the model outputs.
        amino_acids = set('ARNDCEQGHILKMFPSTWYV')
        def valid(seq):
            return not pd.isna(seq) and all(aa in amino_acids for aa in str(seq))
        keep = dataframe['CDR3'].map(valid) & dataframe['Antigen'].map(valid)
        print(f"Dropping {int((~keep).sum())} filtered rows")
        dataframe = dataframe[keep].reset_index(drop=True)

    dataframe['Score'] = outputs
    return dataframe


if __name__ == '__main__':
    df = predict('mcpas', 'test_neg_part1.csv')
    print(df)
    df.to_csv('results/2.csv', index=False)
    pass



In the end, 6,190 input rows were fed to the model and 5,993 predictions came back; the 197 missing rows are exactly those filtered out for invalid sequences.
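To confirm which rows were dropped (6190 − 5993 = 197), the same validity rule as read_input_file can be run standalone (a sketch; 'results/dropped_rows.csv' is a hypothetical output path):

import pandas as pd

AMINO_ACIDS = set('ARNDCEQGHILKMFPSTWYV')

def valid_seq(seq):
    # Same rule as invalid() in read_input_file, inverted
    return not pd.isna(seq) and all(aa in AMINO_ACIDS for aa in str(seq))

df = pd.read_csv('test_neg_part1.csv')
keep = df['CDR3'].map(valid_seq) & df['Antigen'].map(valid_seq)
print(f"{(~keep).sum()} rows filtered out")
df[~keep].to_csv('results/dropped_rows.csv', index=False)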