본문 바로가기

ML & AI Theory

텐서플로의 계산 그래프 : 1.0 vs 2.x 버젼 차이

반응형

텐서플로 2.x 버전에서는 즉시 실행 모드가 기본으로 활성화되어 있기 떄문에 계산 그래프를 만들지 않고 빠르게 개발과 테스트를 할 수 있습니다. 모델을 다른 프레임워크와 공유하거나 실행 성능을 높이기 위해 그래프를 만들려면 어떻게 해야 할까요? 

 

단순한 내용으로 알아보겠습니다.

import tensorflow as tf

a = tf.constant(1)
b = tf.constant(2)
c = tf.constant(3)

z = 2*(a-b)+c
print(z.numpy())

 


tensorflow 1.x

 

텐서플로 1.x 버전은 계산 그래프를 만든 후 세션을 통해 그래프를 실행합니다. 텐서플로 1.x 버제서 계산 그래프를 만들고 실행하는 각 단곌르 자세히 정리하면 다음과 같습니다.

 

1. 비어 있는 새로운 계산 그래프를 만듭니다.

2. 계산 그래프에 노드(텐서와 연산)을 추가합니다.

3. 그래프를 실행합니다.

    a. 새로운 세션을 시작합니다.

    b. 그래프에 있는 변수를 초기화합니다.

    c. 이 세션에서 계산 그래프를 실행합니다.

 

위의 명령어에 대한 그래프를 만든다면 아래와 같습니다.

#tensor 1.x
g = tf.Graph()

#그래프에 노드를 추가합니다.
with g.as_default():
    a = tf.constant(1, name='a')
    b = tf.constant(2, name='b')
    c = tf.constant(3, name='c')
    
    z = 2*(a-b)+c
    
#그래프를 실행합니다.
with tf.compat.v1.Session(graph=g) as sess:
    print(sess.run(z))

이 코드에서는 with g.as_default()를 사용하여 그래프 g에 노드를 추가했습니다. 텐서플로 1.x에서는 명시적으로 그래프를 지정하지 않으면 항상 기본 그래프가 사용됩니다. 이 후 tf.Session을 호출하여 세션 객체를 만들고 tf.Session(graph=g)처럼 실행할 그래프를 매개변수로 전달합니다.

 

텐서플로 세션에서 그래프를 적재한 후에는 이 그래프에 있는 노드를 실행시킬 수 있습니다. 여기서 텐서와 연산을 텐서플로의 계산 그래프 안에 정의했다는 것을 기억하세요. 텐서플로 세션은 그래프에 있는 연산을 실행한 후 결과를 평가하고 추출하기 위해 사용됩니다.

 

그래프 g에 들어 있는 연산을 출력해 보겠습니다. tf.Graph()객체의 get_operation()메서드를 사용합니다.

g.get_operations()

 

이번에는 그래프 g의 정의를 출력해 보겠습니다. as_graph_def() 메서드를 호출하면 포맷팅된 문자열로 그래프 정의를 출력합니다.

g.as_graph_def()

 


tensorflow 2.x

2.x버전에서는 tf.function 데코레이터(decorator)를 사용하여 일반 파이썬 함수를 호출 가능한 그래츠 객체로 만듭니다. 마치 tf.Graph와 tf.Session을 합쳐 놓은 것처럼 생각할 수 있습니다. 앞 코드를 tf.function 데코레이터를 사용하여 다시 작성해 보겠습니다.

#tensorflow 2.x
@tf.function
def simple_func():
    a = tf.constant(1, name='a')
    b = tf.constant(2, name='b')
    c = tf.constant(3, name='c')
    
    z = 2*(a-b)+c
    return z

print(simple_func().numpy())

 

simple_func()함수에서 반환하는 텐서를 바로 numpy()메서드로 반환하여 출력했습니다. simple_func()함수는 보통 파이썬 함수처럼 호출할 수 있지만 데코레이터에 의해 객체가 바뀌었습니다.

print(simple_func.__class__)

 

<class 'tensorflow.python.eager.def_function.Function'>

 

파이썬의 다른 데코레이터처럼 다음과 같이 쓰면 simple_func()객체를 좀 더 이해하기 쉽습니다.

def simple_func():
    a = tf.constant(1, name='a')
    b = tf.constant(2, name='b')
    c = tf.constant(3, name='c')
    
    z = 2*(a-b)+c
    return z

simple_func = tf.function(simple_func)
print(simple_func().numpy())

tf.function으로 감싼 함수 안의 연산은 자동으로 텐서플로 그래프에 포함되어 실행됩니다. 이를 자동 그래프(AutoGraph)기능이라고 합니다. 

 

simple_func가 만든 그래프에 있는 연산과 그래프 정의를 확인해 보겠습니다.

con_func = simple_func.get_concrete_function()
con_func.graph.get_operations()

 

[<tf.Operation 'a' type=Const>,
 <tf.Operation 'b' type=Const>,
 <tf.Operation 'c' type=Const>,
 <tf.Operation 'sub' type=Sub>,
 <tf.Operation 'mul/x' type=Const>,
 <tf.Operation 'mul' type=Mul>,
 <tf.Operation 'add' type=AddV2>,
 <tf.Operation 'Identity' type=Identity>]

 

그래프 정의를 얻으려면 텐서플로 1.x와 마찬가지로 GRaph객체의 as_graph_def()메서드를 호출합니다.

con_func.graph.as_graph_def()

 

node {
  name: "a"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "b"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 2
      }
    }
  }
}
node {
  name: "c"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 3
      }
    }
  }
}
node {
  name: "sub"
  op: "Sub"
  input: "a"
  input: "b"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "mul/x"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 2
      }
    }
  }
}
node {
  name: "mul"
  op: "Mul"
  input: "mul/x"
  input: "sub"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "add"
  op: "AddV2"
  input: "mul"
  input: "c"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "add"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 716
}
반응형