/* PSIPRED 3 - Neural Network Prediction of Secondary Structure */

/* Copyright (C) 2000 David T. Jones - Created : January 2000 */
/* Original Neural Network code Copyright (C) 1990 David T. Jones */

/* Average Prediction Module */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <ctype.h>
#include <time.h>

#include "ssdefs.h"
#include "sspred_net.h"

//#define DEBUG
//#define DEBUG_LOG

// Complexity of psipred is 89516 * 615 = 53822340, i.e. about 50 million multiplications
/* Weight-file paths copied from the command line by main(): at most 5 entries
 * of at most 299 characters each. */
char weightFileList[5][300]={ "" };
/* Number of valid entries in weightFileList (set by main()). */
int numWeightFile=0;
/* NOTE(review): old-style redeclarations; <stdlib.h> (included above) already
 * declares both functions, so this line looks redundant - confirm and drop. */
void           *calloc(), *malloc();

/* Not referenced anywhere in the code visible in this file. */
char           *wtfnm;

/* fwt_to[i] / lwt_to[i]: unit i receives input from units
 * fwt_to[i] .. lwt_to[i]-1 (see compute_output()).  nwtsum is not referenced
 * in the code visible in this file. */
int             nwtsum, fwt_to[TOTAL], lwt_to[TOTAL];
/* Per-unit activation and bias values, and per-unit rows of incoming weights
 * (rows are allocated by init()). */
float           activation[TOTAL], bias[TOTAL], *weight[TOTAL];

/* Profile values read by getmtx(): one row per residue, 20 columns indexed by
 * enum aacodes. */
int             profile[MAXSEQLEN][20];

/* Length of the current sequence (main() stores the return value of getmtx()). */
int             seqlen;

/* Sequence string read from the mtx file by getmtx(). */
char seq[MAXSEQLEN];

/* Column indices into profile[][]: the 20 amino acids ordered by their
 * three-letter abbreviation (0..19), followed by UNK (20) for anything else. */
enum aacodes
{
    ALA = 0,
    ARG,
    ASN,
    ASP,
    CYS,
    GLN,
    GLU,
    GLY,
    HIS,
    ILE,
    LEU,
    LYS,
    MET,
    PHE,
    PRO,
    SER,
    THR,
    TRP,
    TYR,
    VAL,
    UNK
};

/* Print the message s followed by a newline on stderr. */
void err(char *s)
{
    fprintf(stderr, "%s\n", s);
}

/* Report the message s through err() and terminate the program with exit
 * status 1.  Does not return. */
void fail(char *s)
{
    err(s);
    exit(1);
}

/*
 * Forward pass using an explicitly supplied weight set.
 *
 * For every unit from NUM_IN up to TOTAL-1 (in index order) the activation is
 * logistic(bias_local[unit] + sum of activation[src] * weight_local[unit][src])
 * with src running over fwt_to[unit] .. lwt_to[unit]-1.  Results are written
 * to the global activation[] array; activation[0..NUM_IN-1] must already hold
 * the network input.
 */
void compute_output_new(float **weight_local, float* bias_local)/*{{{*/
{
    int unit;

    for (unit = NUM_IN; unit < TOTAL; unit++)
    {
        const float *row = weight_local[unit];
        const int    end = lwt_to[unit];
        float        sum = bias_local[unit];
        int          src;

        /* The multiply-accumulate below dominates the run time (>95%). */
        for (src = fwt_to[unit]; src < end; src++)
            sum += activation[src] * row[src];

        /* Trigger neuron */
        activation[unit] = logistic(sum);
    }
}/*}}}*/
/*
 * Forward pass using the global weight[] / bias[] arrays.
 *
 * Units NUM_IN .. TOTAL-1 are evaluated in index order; each one sums
 * activation[src] * weight[unit][src] over src = fwt_to[unit] .. lwt_to[unit]-1,
 * adds bias[unit] and stores logistic() of the total in activation[unit].
 */
void compute_output(void)
{
    int unit;

    for (unit = NUM_IN; unit < TOTAL; unit++)
    {
        const float *row = weight[unit];
        float        sum = bias[unit];
        int          src = fwt_to[unit];

        /* This accumulation is where most of the time goes. */
        while (src < lwt_to[unit])
        {
            sum += activation[src] * row[src];
            src++;
        }

        /* Trigger neuron (cheap compared with the loop above) */
        activation[unit] = logistic(sum);
#ifdef DEBUG_LOG
        fprintf(stdout,"fwt_to[%d] = %d  lwt_to[%d] = %d\n", unit, fwt_to[unit], unit, lwt_to[unit]);
        /*fwt_to .. lwt_to can be 0 .. 615 or 616 .. 690*/
#endif
    }
}
char *rootname(const char* filename, char* rtname)/*{{{*/
/*****************************************************************************
 * rootname
 * Strip the directory part (everything up to and including the last '/')
 * and the last extension (everything from the last '.' of what remains)
 * from filename, and store the result as a NUL-terminated string in rtname.
 *
 * filename : path to process (not modified)
 * rtname   : caller-supplied output buffer.  No size is passed in, so it
 *            must be able to hold strlen(filename)+1 bytes.
 * returns  : rtname
 *
 * The result is always written to rtname, including when the name has no
 * extension, so callers that ignore the return value still see it.
 ****************************************************************************/
{
    const char *slash = strrchr(filename, '/');
    const char *base = (slash != NULL) ? slash + 1 : filename;
    const char *dot = strrchr(base, '.');
    size_t len = (dot != NULL) ? (size_t)(dot - base) : strlen(base);

    /* memmove: stays well-defined even if rtname aliases filename */
    memmove(rtname, base, len);
    rtname[len] = '\0';
    return rtname;
}
/*}}}*/
char** ReadFileList(const char* filename, char **filenameList, int* p_cntfile)/*{{{*/
/*****************************************************************************
 * ReadFileList
 * Read a text file that holds one file name per line and return the
 * non-empty lines as a newly allocated array of newly allocated strings.
 *
 * filename     : path of the list file
 * filenameList : ignored on input (the incoming value is overwritten); kept
 *                for interface compatibility
 * p_cntfile    : receives the number of names stored in the returned array
 * returns      : the array; the caller frees each element and the array
 *
 * Any I/O or allocation failure terminates the program through fail().
 * A trailing '\r' is removed from each line so CRLF list files work.
 ****************************************************************************/
{
    FILE   *fpin;
    long    filesize;
    size_t  nread;
    size_t  maxlines = 1;   /* a last line without '\n' still counts */
    size_t  pos;
    char   *text;
    char   *line;
    int     cntfile = 0;

    fpin = fopen(filename, "r");
    if (fpin == NULL) {
        fail("Unable to open listFile!");
    }
    if (fseek(fpin, 0, SEEK_END) != 0) {
        fail("fseek of the listFile failed");
    }
    filesize = ftell(fpin);
    if (filesize < 0) {
        fail("ftell of the listFile failed");
    }
    if (fseek(fpin, 0, SEEK_SET) != 0) {
        fail("fseek of the listFile failed");
    }

    /* read the whole file into one NUL-terminated string */
    text = malloc((size_t)filesize + 1);
    if (text == NULL) {
        fail("ReadFileList: Out of Memory!");
    }
    /* A text-mode stream may deliver fewer bytes than ftell() reported, so
     * only a stream error is treated as a failure. */
    nread = fread(text, sizeof(char), (size_t)filesize, fpin);
    if (ferror(fpin)) {
        fail("fread of the listFile failed");
    }
    fclose(fpin);
    text[nread] = '\0';

    /* upper bound on the number of names: number of newlines + 1 */
    for (pos = 0; pos < nread; pos++) {
        if (text[pos] == '\n') {
            maxlines++;
        }
    }
    filenameList = malloc(sizeof(char*) * maxlines);
    if (filenameList == NULL) {
        fail("ReadFileList: Out of Memory!");
    }

    /* split the text into lines and keep a copy of every non-empty one */
    for (line = strtok(text, "\n"); line != NULL; line = strtok(NULL, "\n"))
    {
        size_t len = strlen(line);

        if (len > 0 && line[len - 1] == '\r') {
            line[--len] = '\0';
        }
        if (len == 0) {
            continue;
        }
        filenameList[cntfile] = malloc(len + 1);
        if (filenameList[cntfile] == NULL) {
            fail("ReadFileList: Out of Memory!");
        }
        memcpy(filenameList[cntfile], line, len + 1);
        cntfile++;
    }
#ifdef DEBUG
    {
        int i;
        for (i = 0; i < cntfile; i++)
        {
            fprintf(stderr,"ReadFileList, %d:%s\n", i, filenameList[i]);
        }
    }
#endif
    free(text);
    *p_cntfile = cntfile;
    return filenameList;

}/*}}}*/

/*
 * Load one weight set from the text file fname into caller-owned storage.
 *
 * fname        : path of the weight file
 * weight_local : per-unit weight rows; weight_local[i][j] is filled for
 *                i = NUM_IN .. TOTAL-1 and j = fwt_to[i] .. lwt_to[i]-1
 * bias_local   : bias_local[NUM_IN .. TOTAL-1] is filled
 *
 * The file holds the input->hidden weights, the hidden->output weights and
 * the biases, followed by a checksum that must equal the rounded sum of
 * squares of all preceding values.  Every value is verified as it is read;
 * any failure terminates the program through fail().
 */
void load_wts_new(char *fname, float **weight_local, float* bias_local)/*{{{*/
{
	int             i, j;
	double          t = 0.0, chksum = 0.0;
	FILE           *ifp;

#ifdef DEBUG
	fprintf(stderr,"load weight file:%s\n", fname);
#endif

	if (!(ifp = fopen(fname, "r")))
		fail("Cannot open weight file!\n");

	/* Load input units to hidden layer weights */
	for (i = NUM_IN; i < NUM_IN + NUM_HID; i++)
		for (j = fwt_to[i]; j < lwt_to[i]; j++)
		{
			if (fscanf(ifp, "%lf", &t) != 1)
				fail("Weight file read error!");
			weight_local[i][j] = t;
			chksum += t*t;
		}

	/* Load hidden layer to output units weights */
	for (i = NUM_IN + NUM_HID; i < TOTAL; i++)
		for (j = fwt_to[i]; j < lwt_to[i]; j++)
		{
			if (fscanf(ifp, "%lf", &t) != 1)
				fail("Weight file read error!");
			weight_local[i][j] = t;
			chksum += t*t;
		}

	/* Load bias weights */
	for (j = NUM_IN; j < TOTAL; j++)
	{
		if (fscanf(ifp, "%lf", &t) != 1)
			fail("Weight file read error!");
		bias_local[j] = t;
		chksum += t*t;
	}

	/* Read expected checksum at end of file */
	if (fscanf(ifp, "%lf", &t) != 1 || ferror(ifp))
		fail("Weight file read error!");

	fclose(ifp);

	if ((int)t != (int)(chksum+0.5))
		fail("Corrupted weight file detected!");
}/*}}}*/
/*
 * load weights - load all link weights from a disk file
 *
 * Fills the global weight[][] and bias[] arrays from the text file fname.
 * The file holds the input->hidden weights, the hidden->output weights and
 * the biases, followed by a checksum that must equal the rounded sum of
 * squares of all preceding values.  Every value is verified as it is read;
 * any failure terminates the program through fail().
 */
void
load_wts(char *fname)
{
    int             i, j;
    double          t = 0.0, chksum = 0.0;
    FILE           *ifp;

#ifdef DEBUG
    fprintf(stderr,"load weight file:%s\n", fname);
#endif

    if (!(ifp = fopen(fname, "r")))
        fail("Cannot open weight file!\n");

    /* Load input units to hidden layer weights */
    for (i = NUM_IN; i < NUM_IN + NUM_HID; i++)
        for (j = fwt_to[i]; j < lwt_to[i]; j++)
        {
            if (fscanf(ifp, "%lf", &t) != 1)
                fail("Weight file read error!");
            weight[i][j] = t;
            chksum += t*t;
        }

    /* Load hidden layer to output units weights */
    for (i = NUM_IN + NUM_HID; i < TOTAL; i++)
        for (j = fwt_to[i]; j < lwt_to[i]; j++)
        {
            if (fscanf(ifp, "%lf", &t) != 1)
                fail("Weight file read error!");
            weight[i][j] = t;
            chksum += t*t;
        }

    /* Load bias weights */
    for (j = NUM_IN; j < TOTAL; j++)
    {
        if (fscanf(ifp, "%lf", &t) != 1)
            fail("Weight file read error!");
        bias[j] = t;
        chksum += t*t;
    }

    /* Read expected checksum at end of file */
    if (fscanf(ifp, "%lf", &t) != 1 || ferror(ifp))
        fail("Weight file read error!");

    fclose(ifp);

    if ((int)t != (int)(chksum+0.5))
        fail("Corrupted weight file detected!");
}

/*
 * Allocate storage for numWeightFile independent weight sets and set up the
 * network connectivity tables fwt_to[] / lwt_to[].
 *
 * weight_list : ignored on input (the incoming value is overwritten); kept
 *               for interface compatibility
 * returns     : array of numWeightFile tables; table[list][i] is NULL for
 *               input units (i < NUM_IN) and a zero-filled row of
 *               TOTAL - NUM_OUT floats for every other unit.  The caller
 *               frees the rows, the tables and the array.
 *
 * Allocation failure terminates the program through fail().
 */
float ***init_new(float ***weight_list)/*{{{*/
{
    int i;
    int list;
    /* never request 0 bytes: malloc(0) may legitimately return NULL */
    size_t nsets = (numWeightFile > 0) ? (size_t)numWeightFile : 1;

    weight_list = malloc(sizeof(float**) * nsets);
    if (weight_list == NULL) {
        fail("init: Out of Memory!");
    }
    for (list = 0; list < numWeightFile; list++)
    {
        weight_list[list] = malloc(sizeof(float*) * TOTAL);
        if (weight_list[list] == NULL) {
            fail("init: Out of Memory!");
        }
        for (i = 0; i < NUM_IN; i++) {
            weight_list[list][i] = NULL;
        }
        for (i = NUM_IN; i < TOTAL; i++) {
            if (!(weight_list[list][i] = calloc(TOTAL - NUM_OUT, sizeof(float)))) {
                fail("init: Out of Memory!");
            }
        }
    }

    /* Connect input units to hidden layer */
    for (i = NUM_IN; i < NUM_IN + NUM_HID; i++)
    {
        fwt_to[i] = 0;
        lwt_to[i] = NUM_IN;
    }

    /* Connect hidden units to output layer */
    for (i = NUM_IN + NUM_HID; i < TOTAL; i++)
    {
        fwt_to[i] = NUM_IN;
        lwt_to[i] = NUM_IN + NUM_HID;
    }
    return weight_list;
}/*}}}*/
/*
 * Initialize network: give every non-input unit a zero-filled row of
 * TOTAL - NUM_OUT weights in the global weight[] table, then record which
 * range of source units feeds each unit (hidden units read the inputs,
 * output units read the hidden units).  Terminates through fail() when an
 * allocation fails.
 */
void
init(void)
{
    int unit;

    for (unit = NUM_IN; unit < TOTAL; unit++)
    {
        weight[unit] = calloc(TOTAL - NUM_OUT, sizeof(float));
        if (weight[unit] == NULL)
            fail("init: Out of Memory!");
    }

    for (unit = NUM_IN; unit < TOTAL; unit++)
    {
        if (unit < NUM_IN + NUM_HID)
        {
            /* hidden unit: connected to the input units */
            fwt_to[unit] = 0;
            lwt_to[unit] = NUM_IN;
        }
        else
        {
            /* output unit: connected to the hidden units */
            fwt_to[unit] = NUM_IN;
            lwt_to[unit] = NUM_IN + NUM_HID;
        }
    }
}

/*
 * Convert AA letter to numeric code (0-20).
 *
 * ch      : one-letter amino-acid code, upper or lower case
 * returns : 0..19 for the twenty standard residues (same numbering as
 *           enum aacodes), 20 for every other letter and for any
 *           non-letter, including negative or >127 values
 *
 * Letters are recognised with explicit ASCII range tests rather than
 * isalpha(), so the result does not depend on the locale and a negative
 * char value can neither invoke undefined behaviour nor index past the
 * 27-entry table.
 */
int
aanum(int ch)
{
    /* indexed by (letter & 31): slot 0 is unused, 'A'/'a' -> 1 ... 'Z'/'z' -> 26 */
    static const int      aacvs[] =
    {
	999, 0, 20, 4, 3, 6, 13, 7, 8, 9, 20, 11, 10, 12, 2,
	20, 14, 5, 1, 15, 16, 20, 19, 17, 20, 18, 20
    };

    if ((ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z'))
	return aacvs[ch & 31];
    return 20;
}

void predict_new(float*** weight_list, float**bias_list, FILE* fpout)/*{{{*/
{
	int             aa, i, j, k, n, winpos,ws;
	char fname[80], predsst[MAXSEQLEN];
	float           avout[MAXSEQLEN][3], conf, confsum[MAXSEQLEN];

	for (winpos = 0; winpos < seqlen; winpos++)
		avout[winpos][0] = avout[winpos][1] = avout[winpos][2] = confsum[winpos] = 0.0F;

	for (ws=0; ws<numWeightFile; ws++)
	{
//		load_wts(weightFileList[ws]);

		for (winpos = 0; winpos < seqlen; winpos++)
		{
			for (j = 0; j < NUM_IN; j++)
				activation[j] = 0.0;
			for (j = WINL; j <= WINR; j++)
			{
				if (j + winpos >= 0 && j + winpos < seqlen)
				{
					for (aa=0; aa<20; aa++)
						activation[(j - WINL) * IPERGRP + aa] = profile[j+winpos][aa]/1000.0;
					aa = aanum(seq[j+winpos]);
					if (aa < 20)
						activation[(j - WINL) * IPERGRP + 20 + aa] = 1.0;
					else
						activation[(j - WINL) * IPERGRP + 40] = 1.0;
				}
				else
					activation[(j - WINL) * IPERGRP + 40] = 1.0;
			}

		   compute_output_new(weight_list[ws], bias_list[ws]);
		   // compute_output();

			conf = 1.0 - MIN(MIN(activation[TOTAL - NUM_OUT], activation[TOTAL - NUM_OUT+1]), activation[TOTAL - NUM_OUT+2]);

			avout[winpos][0] += conf * activation[TOTAL - NUM_OUT];
			avout[winpos][1] += conf * activation[TOTAL - NUM_OUT+1];
			avout[winpos][2] += conf * activation[TOTAL - NUM_OUT+2];
			confsum[winpos] += conf;
		}
	}

	for (winpos = 0; winpos < seqlen; winpos++)
	{
		avout[winpos][0] /= confsum[winpos];
		avout[winpos][1] /= confsum[winpos];
		avout[winpos][2] /= confsum[winpos];
		if (avout[winpos][0] >= MAX(avout[winpos][1], avout[winpos][2]))
			predsst[winpos] = 'C';
		else if (avout[winpos][2] >= MAX(avout[winpos][0], avout[winpos][1]))
			predsst[winpos] = 'E';
		else
			predsst[winpos] = 'H';
	}

	for (winpos = 0; winpos < seqlen; winpos++)
		fprintf(fpout, "%4d %c %c  %6.3f %6.3f %6.3f\n", winpos + 1, seq[winpos], predsst[winpos], avout[winpos][0], avout[winpos][1], avout[winpos][2]);

}/*}}}*/
/*
 * Make 1st level prediction averaged over specified weight sets.
 *
 * argc, argv : argv[2] .. argv[argc-1] are the weight files; each one is
 *              loaded in turn with load_wts()
 *
 * Works on the sequence currently held in seq[] / profile[] (length seqlen).
 * For each residue and each weight set the window WINL..WINR around the
 * residue is encoded into activation[], the network is evaluated, and the
 * three outputs are accumulated weighted by conf = 1 - (smallest output).
 * The weighted means decide the state ('C' if the first output is largest,
 * otherwise 'E' if the third is largest, otherwise 'H'); one line per
 * residue is printed on stdout.
 */
void
predict(int argc, char **argv)
{
	int    pos, arg, off, aa, k;
	float  avout[MAXSEQLEN][3], confsum[MAXSEQLEN], conf;
	float *const out = &activation[TOTAL - NUM_OUT];

	for (pos = 0; pos < seqlen; pos++)
	{
		confsum[pos] = 0.0F;
		avout[pos][0] = 0.0F;
		avout[pos][1] = 0.0F;
		avout[pos][2] = 0.0F;
	}

	for (arg = 2; arg < argc; arg++)
	{
		load_wts(argv[arg]);

		for (pos = 0; pos < seqlen; pos++)
		{
			for (k = 0; k < NUM_IN; k++)
				activation[k] = 0.0;

			/* encode the sequence window centred on pos */
			for (off = WINL; off <= WINR; off++)
			{
				float    *grp = &activation[(off - WINL) * IPERGRP];
				const int res = off + pos;

				if (res < 0 || res >= seqlen)
				{
					/* window position falls outside the sequence */
					grp[40] = 1.0;
					continue;
				}
				for (aa = 0; aa < 20; aa++)
					grp[aa] = profile[res][aa]/1000.0;
				aa = aanum(seq[res]);
				/* one-hot residue identity; slot 40 doubles as "unknown" */
				grp[aa < 20 ? 20 + aa : 40] = 1.0;
			}

			compute_output();

			conf = 1.0 - MIN(MIN(out[0], out[1]), out[2]);

			avout[pos][0] += conf * out[0];
			avout[pos][1] += conf * out[1];
			avout[pos][2] += conf * out[2];
			confsum[pos] += conf;
		}
	}

	for (pos = 0; pos < seqlen; pos++)
	{
		float *av = avout[pos];
		char   state;

		av[0] /= confsum[pos];
		av[1] /= confsum[pos];
		av[2] /= confsum[pos];

		if (av[0] >= MAX(av[1], av[2]))
			state = 'C';
		else if (av[2] >= MAX(av[0], av[1]))
			state = 'E';
		else
			state = 'H';

		printf("%4d %c %c  %6.3f %6.3f %6.3f\n", pos + 1, seq[pos], state, av[0], av[1], av[2]);
	}
}

/* Read PSI AA frequency data */
int             getmtx(FILE *lfil)
{
    int             aa, i, j, naa;
    char            buf[256], *p;
    
    if (fscanf(lfil, "%d", &naa) != 1)
	fail("Bad mtx file - no sequence length!");
    
    if (naa > MAXSEQLEN)
	fail("Input sequence too long!");
    
    if (fscanf(lfil, "%s", seq) != 1)
	fail("Bad mtx file - no sequence!");
    
    while (!feof(lfil))
    {
	if (!fgets(buf, 65536, lfil))
	    fail("Bad mtx file!");
	if (!strncmp(buf, "-32768 ", 7))
	{
	    for (j=0; j<naa; j++)
	    {
		if (sscanf(buf, "%*d%d%*d%d%d%d%d%d%d%d%d%d%d%d%d%d%d%d%d%d%d%*d%d", &profile[j][ALA],  &profile[j][CYS], &profile[j][ASP],  &profile[j][GLU],  &profile[j][PHE],  &profile[j][GLY],  &profile[j][HIS],  &profile[j][ILE],  &profile[j][LYS],  &profile[j][LEU],  &profile[j][MET],  &profile[j][ASN],  &profile[j][PRO],  &profile[j][GLN],  &profile[j][ARG],  &profile[j][SER],  &profile[j][THR],  &profile[j][VAL],  &profile[j][TRP],  &profile[j][TYR]) != 20)
		    fail("Bad mtx format!");
		aa = aanum(seq[j]);
		if (aa < 20)
		    profile[j][aa] += 0000;
		if (!fgets(buf, 65536, lfil))
		    break;
	    }
	}
    }
    
    return naa;
}
/* Write the command-line usage summary to stdout. */
void PrintHelp()
{
    static const char helpText[] =
        "usage:  psipred  (mtx-file OR -l mtxListFile)  weight-file1 ... weight-filen\n"
        "\n"
        "options:\n"
        "    -outpath <dir> : set output path, default=./\n"
        "    -l <file>      : set the list file\n"
        "    -h|--help  : print this help message\n"
        "\n"
        "updated 2010-11-19, Nanjiang\n";

    fputs(helpText, stdout);
}

/*
 * Copy the command-line string src into the fixed-size buffer dst of
 * dstsize bytes; terminate through fail() instead of overflowing.
 */
static void copy_arg(char *dst, size_t dstsize, const char *src)
{
    if (strlen(src) >= dstsize)
        fail("Command-line argument too long!");
    strcpy(dst, src);
}

/*
 * Program entry point.
 *
 *   psipred mtx-file weight-file1 ... weight-filen
 *       predict for one mtx file, output on stdout
 *   psipred -l listFile [-outpath dir] weight-file1 ... weight-filen
 *       predict for every mtx file named in listFile, writing
 *       <outpath>/<rootname>.ss for each of them
 *
 * Returns 0 on success (and after -h/--help), 1 when called without
 * arguments; every other error terminates with exit status 1.
 */
int main(int argc, char **argv)
{
    const int maxWts = (int)(sizeof weightFileList / sizeof weightFileList[0]);
    int     i;
    int     nWeightArgs = 0;    /* weight files named on the command line */
    FILE   *ifp;
    char    mtxFile[500] = "";
    char    listFile[500] = "";
    char    outpath[500] = "./";
    char  **predArgv;           /* argv-shaped vector for predict(): prog, mtx, weights... */

    if (argc < 2)
    {
        PrintHelp();
        return 1;
    }

    predArgv = malloc(sizeof *predArgv * ((size_t)argc + 2));
    if (predArgv == NULL)
        fail("main: Out of Memory!");
    predArgv[0] = argv[0];
    predArgv[1] = mtxFile;

    i = 1;
    while (i < argc)
    {
        if (argv[i][0] == '-')
        {
            if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0)
            {
                PrintHelp();
                free(predArgv);
                return 0;
            }
            else if (strcmp(argv[i], "-outpath") == 0 || strcmp(argv[i], "--outpath") == 0)
            {
                if (i + 1 >= argc)
                    fail("Option -outpath requires an argument!");
                copy_arg(outpath, sizeof outpath, argv[i + 1]);
                i += 2;
            }
            else if (strcmp(argv[i], "-l") == 0 || strcmp(argv[i], "--l") == 0)
            {
                if (i + 1 >= argc)
                    fail("Option -l requires an argument!");
                copy_arg(listFile, sizeof listFile, argv[i + 1]);
                i += 2;
            }
            else
            {
                fprintf(stderr, "Wrong argument:%s\n", argv[i]);
                exit(1);
            }
        }
        else
        {
            if (strcmp(listFile, "") == 0 && strcmp(mtxFile, "") == 0)
            {
                copy_arg(mtxFile, sizeof mtxFile, argv[i]);
            }
            else
            {
                /* predict() takes its weight files from an argv-style
                   vector, so keep every one of them there; the fixed-size
                   table used in list mode holds only the first maxWts */
                predArgv[2 + nWeightArgs] = argv[i];
                if (nWeightArgs < maxWts)
                    copy_arg(weightFileList[nWeightArgs], sizeof weightFileList[nWeightArgs], argv[i]);
                nWeightArgs++;
            }
            i += 1;
        }
    }
    numWeightFile = (nWeightArgs < maxWts) ? nWeightArgs : maxWts;

    if (strcmp(mtxFile, "") != 0 && strcmp(listFile, "") != 0) {
        fail("Either mtxFile or listFile should be set!");
    }
    if (strcmp(mtxFile, "") == 0 && strcmp(listFile, "") == 0) {
        fail("Either mtxFile or listFile should be set!");
    }
    /* without any weight set the averaged outputs would be 0/0 */
    if (nWeightArgs == 0) {
        fail("No weight file given!");
    }

    for (i = 0; i < TOTAL; i++) {
        weight[i] = NULL;
    }

    if (strcmp(mtxFile, "") != 0)/*{{{*/
    {
#ifdef DEBUG
        fprintf(stderr, "predicting with %s\n", mtxFile);
#endif
        ifp = fopen(mtxFile, "r");
        if (!ifp)
            fail("Cannot open mtx file!");
        seqlen = getmtx(ifp);
        fclose(ifp);

        init();
        /* hand predict() exactly the positional arguments, so options
           placed before the mtx file are not mistaken for weight files */
        predict(nWeightArgs + 2, predArgv);
    }/*}}}*/
    else /* list mode *//*{{{*/
    {
        char     cmd[sizeof outpath + sizeof "mkdir -p "];
        char     rtname[500] = "";
        char     outfile[500] = "";
        int      cntfile = 0;
        int      ifile;
        int      ws;
        char   **filenameList;
        float ***weight_list;
        float  **bias_list;

        if (nWeightArgs > maxWts)
            fprintf(stderr, "Warning: only the first %d weight files are used\n", maxWts);

        /* NOTE(review): outpath is handed to the shell unquoted, so paths
           containing blanks or shell metacharacters will misbehave. */
        snprintf(cmd, sizeof cmd, "mkdir -p %s", outpath);
        if (system(cmd) != 0)
            fprintf(stderr, "Warning: could not create output path %s\n", outpath);

        filenameList = ReadFileList(listFile, NULL, &cntfile);

        /* init_new() allocates the per-set weight tables and sets up
           fwt_to[] / lwt_to[]; the global weight[] table is not used in
           this mode, so init() is not called */
        weight_list = init_new(NULL);
        bias_list = malloc(sizeof *bias_list * (size_t)numWeightFile);
        if (bias_list == NULL)
            fail("main: Out of Memory!");
        for (ws = 0; ws < numWeightFile; ws++)
        {
            bias_list[ws] = malloc(sizeof(float) * TOTAL);
            if (bias_list[ws] == NULL)
                fail("main: Out of Memory!");
            load_wts_new(weightFileList[ws], weight_list[ws], bias_list[ws]);
        }

        for (ifile = 0; ifile < cntfile; ifile++)
        {
            const char *root;
            FILE       *fpout;
            int         n;

#ifdef DEBUG
            fprintf(stderr, "%d: %s\n", ifile, filenameList[ifile]);
#endif
            /* rootname() has no size parameter: check before calling */
            if (strlen(filenameList[ifile]) >= sizeof rtname)
                fail("File name in listFile too long!");
            /* use the return value: it is valid whether or not the name
               has an extension */
            root = rootname(filenameList[ifile], rtname);
            n = snprintf(outfile, sizeof outfile, "%s/%s.ss", outpath, root);
            if (n < 0 || (size_t)n >= sizeof outfile)
                fail("Output file name too long!");

            ifp = fopen(filenameList[ifile], "r");
            if (!ifp)
            {
                fprintf(stderr, "Cannot open mtx file %s\n", filenameList[ifile]);
                exit(1);
            }
            seqlen = getmtx(ifp);
            fclose(ifp);

            fpout = fopen(outfile, "w");
            if (!fpout)
            {
                fprintf(stderr, "Cannot open output file %s\n", outfile);
                exit(1);
            }
            predict_new(weight_list, bias_list, fpout);
            fclose(fpout);
        }

        for (ifile = 0; ifile < cntfile; ifile++)
        {
            free(filenameList[ifile]);
        }
        free(filenameList);

        for (ws = 0; ws < numWeightFile; ws++)
        {
            for (i = 0; i < TOTAL; i++) {
                free(weight_list[ws][i]);
            }
            free(weight_list[ws]);
            free(bias_list[ws]);
        }
        free(weight_list);
        free(bias_list);
    }/*}}}*/

    for (i = 0; i < TOTAL; i++) {
        free(weight[i]);
    }
    free(predArgv);

    return 0;
}
