Skip to content

AttributeError: 'NoneType' object has no attribute 'taskgraph' #27

@Waterpine

Description

@Waterpine

Hi EPL team,

When I use the EPL library to train a model with the following code:

import os
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import tensorflow as tf
import epl

def preprocess_image(image, size=224):
    """Resize, center-crop and normalize an image for the network.

    The shorter side is scaled to ``size`` pixels while the longer side is
    scaled by the same ratio (truncated to an integer). The result is then
    cropped around its center to a ``size`` x ``size`` square, converted to
    a float array in [0, 1] and standardized per channel.

    Args:
        image: An object exposing ``size`` as ``(width, height)`` plus
            ``resize`` and ``crop`` methods, convertible with ``np.array``.
            The caller in this script passes an RGB ``PIL.Image``.
        size: Side length in pixels of the square output. Defaults to 224,
            the value that was previously hard-coded.

    Returns:
        A ``float32`` NumPy array; ``(size, size, 3)`` for a 3-channel input.
    """
    width, height = image.size
    if width > height:
        # Landscape: height becomes `size`, the surplus width is cropped evenly.
        new_width = int(size * width / height)
        image = image.resize((new_width, size))
        left = (new_width - size) / 2
        image = image.crop((left, 0, left + size, size))
    else:
        # Portrait or square: width becomes `size`, the surplus height is cropped.
        new_height = int(size * height / width)
        image = image.resize((size, new_height))
        top = (new_height - size) / 2
        image = image.crop((0, top, size, top + size))

    # Scale pixel values to [0, 1].
    pixels = np.array(image, dtype=np.float32) / 255.0

    # The constants are kept in float32 on purpose: float64 constants would
    # silently upcast the whole image and double the memory of every batch.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[None, None, :]
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)[None, None, :]

    return (pixels - mean) / std


def load_and_preprocess_image(path):
    """Open the image at ``path``, convert it to RGB and preprocess it.

    Args:
        path: Filesystem path of the image file.

    Returns:
        Whatever ``preprocess_image`` returns for the RGB image.
    """
    # This function runs on many worker threads; the context manager releases
    # the underlying file handle as soon as the RGB copy has been made instead
    # of leaving that to garbage collection.
    with Image.open(path) as raw:
        rgb = raw.convert('RGB')
    return preprocess_image(rgb)


# Dataset roots; each is expected to contain one sub-directory per class.
train_image_dir = '/users/Master/imagenet/train'
val_image_dir = '/users/Master/imagenet/val'

# Class names come from the training directory; the sorted position of a name
# is its integer label.
class_names = sorted(os.listdir(train_image_dir))
num_classes = len(class_names)

# Parallel lists: paths[i] belongs to the class labels[i].
train_image_paths, train_labels = [], []
val_image_paths, val_labels = [], []

for label, class_name in enumerate(class_names):
    # Walk the train split first, then the val split, for the same class.
    for split_root, paths, labels in (
        (train_image_dir, train_image_paths, train_labels),
        (val_image_dir, val_image_paths, val_labels),
    ):
        class_dir = os.path.join(split_root, class_name)
        file_names = os.listdir(class_dir)
        paths.extend(os.path.join(class_dir, file_name) for file_name in file_names)
        labels.extend([label] * len(file_names))


def load_images_parallel(image_paths, num_workers=16):
    """Load and preprocess all ``image_paths`` on a thread pool.

    Args:
        image_paths: Iterable of image file paths.
        num_workers: Maximum number of worker threads.

    Returns:
        A NumPy array built from the per-image results, in input order.
    """
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # Executor.map yields results in the order of its inputs.
        results = pool.map(load_and_preprocess_image, image_paths)
        loaded = list(results)
    return np.array(loaded)


def load_images_chunk(image_paths, labels, batch_size):
    """Yield ``(images, one_hot_labels)`` batches over the given dataset.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer labels, parallel to ``image_paths``.
        batch_size: Number of samples per batch; the last batch may be smaller.

    Yields:
        A tuple of the loaded image array and the one-hot label array
        (``num_classes`` columns) for each consecutive batch.
    """
    total = len(image_paths)
    batch_count = int(np.ceil(total / batch_size))
    for index in range(batch_count):
        window = slice(index * batch_size, (index + 1) * batch_size)
        path_batch = image_paths[window]
        label_batch = labels[window]
        images = load_images_parallel(path_batch)
        one_hot = tf.keras.utils.to_categorical(label_batch, num_classes=num_classes)
        yield images, one_hot


def conv2d_bn(x, filters, kernel_size, strides=1, padding='same', activation=tf.nn.relu, name=None,
              training=True):
    """Apply a bias-free 2-D convolution, batch normalization and an activation.

    Args:
        x: Input tensor.
        filters: Number of convolution output channels.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        padding: Convolution padding mode.
        activation: Callable applied after batch normalization, or ``None``
            to return the normalized tensor unchanged.
        name: Name given to the convolution layer (the batch-norm layer is
            left unnamed).
        training: Forwarded to ``tf.layers.batch_normalization``. Defaults to
            ``True``, which is what this function previously hard-coded, so
            existing callers are unaffected.

    Returns:
        The output tensor.
    """
    x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, use_bias=False, name=name)
    # NOTE(review): if training=False is ever passed, the moving statistics are
    # used; per the TF docs their update ops live in tf.GraphKeys.UPDATE_OPS
    # and must be run with the train op -- confirm before relying on it.
    x = tf.layers.batch_normalization(x, training=training)
    if activation is not None:
        x = activation(x)
    return x


def identity_block(input_tensor, filters, stage, block):
    """Residual block whose shortcut is the unmodified input.

    Three ``conv2d_bn`` layers (1x1, 3x3, 1x1) are stacked; the last one has
    no activation. The input is added to their output and ReLU is applied.

    Args:
        input_tensor: Input tensor; it is added to the branch output, so its
            shape must be compatible with ``filters[2]`` channels.
        filters: Sequence of three channel counts, one per convolution.
        stage: Stage identifier used in the layer names.
        block: Block identifier (string) used in the layer names.

    Returns:
        The output tensor of the block.
    """
    filters1, filters2, filters3 = filters
    # Convolution layers are named 'res<stage><block>_branch2a/2b/2c'.
    conv_name_base = 'res' + str(stage) + block + '_branch'

    x = conv2d_bn(input_tensor, filters1, 1, name=conv_name_base + '2a')
    x = conv2d_bn(x, filters2, 3, name=conv_name_base + '2b')
    x = conv2d_bn(x, filters3, 1, activation=None, name=conv_name_base + '2c')

    x = tf.add(x, input_tensor)
    x = tf.nn.relu(x)
    return x


def conv_block(input_tensor, filters, stage, block, strides=2):
    """Residual block whose shortcut is a strided 1x1 convolution.

    The main branch stacks three ``conv2d_bn`` layers (1x1 with ``strides``,
    3x3, 1x1 without activation). The shortcut applies a 1x1 ``conv2d_bn``
    with the same ``strides`` and no activation to the input. Both are added
    and ReLU is applied.

    Args:
        input_tensor: Input tensor.
        filters: Sequence of three channel counts, one per main-branch
            convolution; the shortcut uses the third.
        stage: Stage identifier used in the layer names.
        block: Block identifier (string) used in the layer names.
        strides: Stride of the first main-branch convolution and of the
            shortcut convolution.

    Returns:
        The output tensor of the block.
    """
    filters1, filters2, filters3 = filters
    # Convolution layers are named 'res<stage><block>_branch2a/2b/2c' and '...1'.
    conv_name_base = 'res' + str(stage) + block + '_branch'

    x = conv2d_bn(input_tensor, filters1, 1, strides=strides, name=conv_name_base + '2a')
    x = conv2d_bn(x, filters2, 3, name=conv_name_base + '2b')
    x = conv2d_bn(x, filters3, 1, activation=None, name=conv_name_base + '2c')

    shortcut = conv2d_bn(input_tensor, filters3, 1, strides=strides, activation=None, name=conv_name_base + '1')

    x = tf.add(x, shortcut)
    x = tf.nn.relu(x)
    return x


def resnet50(input_tensor, classes):
    """Build the ResNet-50 style network and return its logits.

    Layout: a 7x7 stride-2 stem with max pooling, four residual stages
    (3, 4, 6 and 3 blocks), 7x7 average pooling, flatten and a dense layer.

    Args:
        input_tensor: Input image tensor.
        classes: Number of output units of the final dense layer.

    Returns:
        The un-activated output of the final dense layer.
    """
    # (stage id, channel widths, block letters, stride of the stage's first block)
    stage_specs = (
        (2, (64, 64, 256), 'abc', 1),
        (3, (128, 128, 512), 'abcd', 2),
        (4, (256, 256, 1024), 'abcdef', 2),
        (5, (512, 512, 2048), 'abc', 2),
    )

    net = conv2d_bn(input_tensor, 64, 7, strides=2, name='conv1')
    net = tf.layers.max_pooling2d(net, 3, strides=2, padding='same', name='pool1')

    for stage, widths, block_ids, first_stride in stage_specs:
        # Each stage opens with a projection block, followed by identity blocks.
        net = conv_block(net, list(widths), stage=stage, block=block_ids[0], strides=first_stride)
        for block_id in block_ids[1:]:
            net = identity_block(net, list(widths), stage=stage, block=block_id)

    net = tf.layers.average_pooling2d(net, 7, strides=1, padding='valid', name='pool5')
    net = tf.layers.flatten(net)
    return tf.layers.dense(net, classes, activation=None, name='fc1000')


def run_model():
    """Build the graph, train for a fixed number of epochs and validate.

    Inside a single ``tf.Session`` this creates the input/label placeholders,
    the ResNet logits, a softmax cross-entropy loss minimized with Adam, and
    an accuracy op. Each epoch it prints loss and accuracy for every training
    batch, then prints the mean of the per-batch validation accuracies.
    """
    learning_rate = 0.001
    epochs = 10
    batch_size = 64

    with tf.Session() as sess:
        input_tensor = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name="input_image")
        labels_tensor = tf.placeholder(tf.float32, shape=[None, num_classes], name="labels")

        logits = resnet50(input_tensor, num_classes)

        cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels_tensor)
        loss_op = tf.reduce_mean(cross_entropy)
        train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss_op)

        is_correct = tf.equal(tf.argmax(logits, 1), tf.argmax(labels_tensor, 1))
        accuracy_op = tf.reduce_mean(tf.cast(is_correct, tf.float32))

        def feeds(images, one_hot):
            # Map one batch onto the two placeholders.
            return {input_tensor: images, labels_tensor: one_hot}

        sess.run(tf.global_variables_initializer())

        for epoch in range(epochs):
            # Training pass; the step counter restarts at 0 every epoch.
            train_batches = load_images_chunk(train_image_paths, train_labels, batch_size)
            for step, (batch_images, batch_labels_one_hot) in enumerate(train_batches):
                _, loss, accuracy = sess.run(
                    [train_op, loss_op, accuracy_op],
                    feed_dict=feeds(batch_images, batch_labels_one_hot),
                )
                print(f"Epoch {epoch + 1}/{epochs}, Step: {step}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

            # Validation pass: unweighted mean of the per-batch accuracies.
            val_batches = load_images_chunk(val_image_paths, val_labels, batch_size)
            val_accuracy_list = [
                sess.run(accuracy_op, feed_dict=feeds(batch_images, batch_labels_one_hot))
                for batch_images, batch_labels_one_hot in val_batches
            ]
            val_accuracy = np.mean(val_accuracy_list)
            print(f"Validation Accuracy: {val_accuracy:.4f}")


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)

    # Initialize EPL with an empty configuration before anything else uses it.
    epl.init(epl.Config({}))

    gpus_per_worker = epl.Env.get().cluster.gpu_num_per_worker
    print(gpus_per_worker)
    if gpus_per_worker > 1:
        # Avoid NCCL hang.
        os.environ["NCCL_LAUNCH_MODE"] = "GROUP"

    # Set the default strategy to replication with device_count=1, then train.
    epl.set_default_strategy(epl.replicate(device_count=1))
    run_model()

I run into the following error:
Traceback (most recent call last):
File "resnet50_split3.py", line 203, in <module>
run_model()
File "resnet50_split3.py", line 164, in run_model
sess.run(tf.global_variables_initializer())
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 453, in run
assign_ops = _init_local_resources(self, fn)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 416, in _init_local_resources
assign_ops = broadcast_variables()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 339, in broadcast_variables
bcast_variables = taskgraph.get_variables(replica_idx)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/ir/taskgraph.py", line 409, in get_variables
if id(var_tensor.taskgraph) != id(self):
AttributeError: 'NoneType' object has no attribute 'taskgraph'

Could you give me a hand when you are free? Thank you very much!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions