import os
import pickle
import sys
from argparse import Namespace
from os import listdir
from os.path import isfile, join

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from Loader import SignedPairsDataset, get_index_dicts
from Trainer import ERGOLightning
def read_input_file(datafile):
    """Read a CSV of TCR/antigen pairs and build ERGO-style sample dicts.

    Parameters
    ----------
    datafile : str or file-like
        CSV with at least 'CDR3' and 'Antigen' columns. An optional 'HLA'
        column supplies the MHC allele; missing/NaN alleles become 'UNK'.

    Returns
    -------
    tuple[list[dict], pandas.DataFrame]
        The valid sample dicts and the raw dataframe. Rows whose CDR3 or
        peptide is missing or contains a non-standard amino acid are skipped
        in the list but kept in the dataframe, so the two may differ in length.
    """
    amino_acids = set('ARNDCEQGHILKMFPSTWYV')

    def invalid(seq):
        # NaN (missing) or any non-standard amino-acid letter disqualifies a sequence.
        return pd.isna(seq) or any(aa not in amino_acids for aa in seq)

    data = pd.read_csv(datafile)
    # BUG FIX: the previous `data.get('HLA', [None])[index]` raised IndexError
    # for every row after the first when the 'HLA' column was absent (the
    # fallback list has a single element). Check for the column once instead.
    has_hla = 'HLA' in data.columns
    all_pairs = []
    for index in range(len(data)):
        sample = {}
        sample['tcrb'] = data['CDR3'][index]
        sample['peptide'] = data['Antigen'][index]
        hla_value = data['HLA'][index] if has_hla else None
        if pd.isna(hla_value):
            sample['mhc'] = 'UNK'
        else:
            # Normalize e.g. 'A*02:01' -> 'HLA-A02:01' (ERGO's expected format).
            sample['mhc'] = f"HLA-{str(hla_value).replace('*', '')}"
        # Alpha-chain and gene-segment information is not present in this
        # input format; ERGO accepts 'UNK' placeholders.
        sample['tcra'] = 'UNK'
        sample['va'] = 'UNK'
        sample['ja'] = 'UNK'
        sample['vb'] = 'UNK'
        sample['jb'] = 'UNK'
        sample['t_cell_type'] = 'UNK'
        sample['protein'] = 'UNK'
        sample['sign'] = 0
        if invalid(sample['tcrb']) or invalid(sample['peptide']):
            continue
        all_pairs.append(sample)
    return all_pairs, data
def load_model(hparams, checkpoint_path):
    """Instantiate an ERGOLightning model, restore its weights from a
    checkpoint on CPU, and return it in eval mode."""
    saved_state = torch.load(checkpoint_path, map_location='cpu')
    model = ERGOLightning(hparams)
    model.load_state_dict(saved_state['state_dict'])
    model.eval()
    return model
def get_model(dataset):
    """Load the trained ERGO model and sample-pickle paths for a dataset.

    Parameters
    ----------
    dataset : str
        Either 'vdjdb' or 'mcpas'.

    Returns
    -------
    tuple
        (model, train_pickle_path, test_pickle_path).

    Raises
    ------
    ValueError
        If `dataset` is not a known dataset name.
    """
    # BUG FIX: an unknown dataset previously left `version` unbound and
    # crashed later with UnboundLocalError; fail fast with a clear error.
    versions = {'vdjdb': '1veajht', 'mcpas': '1meajht'}
    if dataset not in versions:
        raise ValueError(f"Unknown dataset {dataset!r}; expected 'vdjdb' or 'mcpas'")
    version = versions[dataset]
    checkpoint_dir = os.path.join('Models', 'version_' + version, 'checkpoints')
    files = [f for f in listdir(checkpoint_dir) if isfile(join(checkpoint_dir, f))]
    # Assumes the checkpoints directory holds a single checkpoint file.
    checkpoint_path = os.path.join(checkpoint_dir, files[0])
    args_path = os.path.join('Models', 'version_' + version, 'meta_tags.csv')
    args = {}
    with open(args_path, 'r') as file:
        for line in file.readlines()[1:]:  # skip the header row
            key, value = line.strip().split(',')
            if key in ['dataset', 'tcr_encoding_model', 'cat_encoding']:
                args[key] = value
            else:
                # NOTE(security): eval() on file contents — meta_tags.csv is a
                # trusted local training artifact; never point this at
                # untrusted input.
                args[key] = eval(value)
    hparams = Namespace(**args)
    model = load_model(hparams, checkpoint_path)
    train_pickle = 'Samples/' + model.dataset + '_train_samples.pickle'
    test_pickle = 'Samples/' + model.dataset + '_test_samples.pickle'
    return model, train_pickle, test_pickle
def get_train_dicts(train_pickle):
    """Unpickle the training samples and return their index dictionaries."""
    with open(train_pickle, 'rb') as fh:
        train_samples = pickle.load(fh)
    return get_index_dicts(train_samples)
# BUG FIX: the file previously contained a second, truncated definition of
# `predict` immediately before this one (it stopped after `outputs = []`);
# the dead duplicate, which was silently shadowed, has been removed.
def predict(dataset, test_file):
    """Score TCR-peptide pairs from `test_file` with the trained `dataset` model.

    Parameters
    ----------
    dataset : str
        'vdjdb' or 'mcpas' — selects which trained model to load.
    test_file : str
        Path to a CSV accepted by `read_input_file`.

    Returns
    -------
    pandas.DataFrame
        The input dataframe with an added 'Score' column.

    Raises
    ------
    ValueError
        If the number of model outputs differs from the number of input rows
        (e.g. when `read_input_file` dropped invalid rows).
    """
    model, train_file, test_pickle = get_model(dataset)
    train_dicts = get_train_dicts(train_file)
    test_samples, dataframe = read_input_file(test_file)
    test_dataset = SignedPairsDataset(test_samples, train_dicts)
    loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        collate_fn=lambda b: test_dataset.collate(
            b,
            tcr_encoding=model.tcr_encoding_model,
            cat_encoding=model.cat_encoding,
        ),
    )
    outputs = []
    for batch_idx, batch in enumerate(loader):
        output = model.validation_step(batch, batch_idx)
        if output:
            outputs.extend(output['y_hat'].tolist())
    if len(outputs) != len(dataframe):
        print(f"Length mismatch detected!")
        print(f"Outputs length: {len(outputs)}")
        print(f"Dataframe length: {len(dataframe)}")
        print(f"Sample outputs: {outputs[:5]}")
        raise ValueError(
            f"Length mismatch: outputs ({len(outputs)}) vs dataframe ({len(dataframe)})"
        )
    dataframe['Score'] = outputs
    return dataframe
if __name__ == '__main__':
    df = predict('mcpas', 'test_neg_part1.csv')
    print(df)
    # BUG FIX: DataFrame.to_csv does not create missing directories; ensure
    # the output directory exists before writing.
    os.makedirs('results', exist_ok=True)
    df.to_csv('results/2.csv', index=False)