diff --git a/losses.py b/losses.py index 7a4cc6b..ab98c95 100644 --- a/losses.py +++ b/losses.py @@ -58,10 +58,8 @@ def style_loss(grams, target_grams, style_weights): for i in xrange(num_style_layers): gram, target_gram = grams[i], target_grams[i] style_weight = style_weights[i] - _, c1, c2 = gram.get_shape().as_list() - size = c1*c2 loss = tf.reduce_sum(tf.square(gram - tf.constant(target_gram))) - loss = style_weight * loss / size + loss = style_weight * loss style_losses.append(loss) style_loss = tf.add_n(style_losses, name='style_loss') return style_loss