/* NiuTrans.Tensor - an open-source tensor library
 * Copyright (C) 2018, Natural Language Processing Lab, Northeastern University.
 * All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * $Created by: XIAO Tong (xiaotong@mail.neu.edu.cn) 2018-07-04
 */

#include <stdio.h>
#include "XGlobal.h"
#include "XTensor.h"

#ifndef __XLINK_H__
#define __XLINK_H__

#include "XGlobal.h"

namespace nts{ // namespace nts(NiuTrans.Tensor)

/* cross reference: forward declaration of XTensor, which XLink refers to
   by pointer only in this header */
struct XTensor;

/* size (in chars) of the XLink::type buffer that stores the name of
   the hyperedge type */
#define MAX_OP_NAME_LENGTH 16
/* NOTE(review): not used anywhere in this header. The name looks like a
   misspelling of "PARAM_UNIT_SIZE" and presumably relates to the size of a
   parameter slot (see XLink::paramSize) - confirm against the implementation
   file. Do not rename it here: other files may reference the current spelling. */
#define PARAM_UNTI_SIZE    64

/*
This defines the links among tensors in a network. An XLink can be
viewed as a hyperedge in a graph. When we compute on tensors, we actually create a
network where the nodes are tensors and the edges are the connections among them.
Each connection is a hyperedge whose head is the output tensor and whose tails are
the input tensors. E.g.,
c = a + b
represents a network with three nodes (a, b and c) and a hyperedge that links a and b (tails) to c (head).
 
   + (=c)
  / \
 a   b
 
For c, we have an incoming edge (a, b) -> c.
For a, we also have an edge c -> a in the reverse order (when viewing the network as a directed acyclic graph).
*/
struct XLink
{
    /*
    Overview of the interface declared below:
    - the data members and the non-static methods describe ONE hyperedge
      (its head, its tails, its parameters and its type);
    - the static methods take the tensors they work on as arguments and
      are not bound to a particular XLink instance.

    NOTE(review): the struct holds raw pointers (tails, params, caculator)
    and declares a destructor, but no copy constructor or copy assignment
    is declared here, so copies are member-wise. Which pointers this struct
    owns is not visible in this header - confirm in the implementation
    before copying XLink objects.
    */

    /* head of the hyperedge (the output tensor) */
    XTensor *  head;

    /* tails of the hyperedge (an array of pointers to the input tensors) */
    XTensor ** tails;

    /* number of tails */
    int tailNum;

    /* parameters used. e.g., in c = a * b * \alpha 
       the scalar \alpha is the parameter. The buffer is untyped; see
       GetParam, GetParamInt and GetParamTrans for typed access */
    void * params;

    /* number of parameters */
    int paramNum;

    /* size of each parameter. This is a static member, i.e., a single value
       shared by all XLink instances.
       NOTE(review): presumably measured in bytes and related to
       PARAM_UNTI_SIZE - confirm where it is defined */
    static int paramSize;

    /* name of the hyperedge type. e.g., sum, mul ...
       (a fixed-size buffer of MAX_OP_NAME_LENGTH chars) */
    char type[MAX_OP_NAME_LENGTH];
    
    /* type id */
    int typeID;

    /* calculator (pointer to the class for computation).
       note that the member name keeps its historical spelling "caculator" */
    void * caculator;
    
    /* constructor */
    XLink();
    
    /* destructor */
    ~XLink();

    /* reset it */
    void Reset();

    /* clear it */
    void Clear();

    /* clear tails */
    void ClearTail();

    /* clear the incoming node list of a tensor node
       >> node - the tensor whose incoming list is cleared */
    static
    void ClearIncoming(XTensor * node);
    
    /* clear the outgoing node list of a tensor node
       >> node - the tensor whose outgoing list is cleared */
    static
    void ClearOutgoing(XTensor * node);

    /* set edge type id and name
       >> id - id of the edge type */
    void SetType(int id);

    /* set head
       >> h - the head (output) tensor */
    void SetHead(XTensor * h);

    /* add a tail
       >> t - the tail (input) tensor to add */
    void AddTail(XTensor * t);

    /* add two tails at one time
       >> t1 - the first tail
       >> t2 - the second tail */
    void AddTwoTails(XTensor * t1, XTensor * t2);

    /* add a parameter in the default type (DTYPE) */
    void AddParam(DTYPE param);

    /* add a parameter given as untyped memory
       >> param - pointer to the parameter
       >> size - size of the parameter
                 (NOTE(review): presumably in bytes - confirm) */
    void AddParam(void * param, int size);

    /* get a parameter in the default type (DTYPE)
       >> i - selects the parameter
              (NOTE(review): presumably a 0-based index - confirm) */
    DTYPE GetParam(int i);

    /* get a parameter as an integer
       >> i - selects the parameter */
    int GetParamInt(int i);
    
    /* get a parameter as a MATRIX_TRANS_TYPE
       >> i - selects the parameter */
    MATRIX_TRANS_TYPE GetParamTrans(int i);

    /* create a hyperedge with two input tensors and an output tensor
       >> t1 - the first input tensor
       >> t2 - the second input tensor
       >> h - the output tensor (head)
       >> id - id of the edge type */
    static
    void MakeLink(const XTensor * t1, const XTensor * t2, XTensor * h, int id);

    /* create a hyperedge with a list of input tensors and an output tensor
       >> list - the input tensors
       >> h - the output tensor (head)
       >> id - id of the edge type */
    static
    void MakeLink(const XList * list, XTensor * h, int id);

    /* create a hyperedge with an input tensor and a list of output tensors.
       note that in this overload the first argument (named h) is the input
       and the list holds the outputs
       >> h - the input tensor
       >> list - the output tensors
       >> id - id of the edge type */
    static
    void MakeLink(XTensor * h, XList * list, int id);

    /* add a parameter (in DTYPE) for the tensor h */
    static
    void AddParamToHead(XTensor * h, DTYPE param);

    /* add an integer parameter for the tensor h */
    static
    void AddParamToHeadInt(XTensor * h, int param);

    /* add a MATRIX_TRANS_TYPE parameter for the tensor h */
    static
    void AddParamToHeadTrans(XTensor * h, MATRIX_TRANS_TYPE param);

    /* add a boolean parameter for the tensor h */
    static
    void AddParamToHeadBool(XTensor * h, bool param);

    /* add a pointer parameter for the tensor h */
    static
    void AddParamToHeadPointer(XTensor * h, void * param);

    /* replace a node with another, i.e., we redirect the links to the new node
       >> oldOne - the node to be replaced
       >> newOne - the node that the links are redirected to */
    static 
    void Replace(const XTensor * oldOne, XTensor * newOne);

    /* copy the incoming links of a given node
       >> reference - the node whose links are copied
       >> target - the node that receives the copy */
    static
    void CopyIncoming(const XTensor * reference, XTensor * target);

    /* check the correctness of the network encoded in a root node (tensor) */
    static
    void CheckNetwork(XTensor * root);

    /* show the network encoded in a root node (tensor)
       >> file - where the network is shown
       >> root - the root node of the network */
    static
    void ShowNetwork(FILE * file, XTensor * root);
};
    
} // namespace nts(NiuTrans.Tensor)

#endif // __XLINK_H__