Initial model (broken)
This commit is contained in:
commit
84712a1740
155
model.py
Normal file
155
model.py
Normal file
|
@ -0,0 +1,155 @@
|
|||
import datetime
|
||||
import csv
|
||||
import os
|
||||
import tensorflow as tf
|
||||
from keras import *
|
||||
from keras.layers import *
|
||||
from keras.backend import *
|
||||
from keras.optimizers import *
|
||||
from keras.preprocessing.image import *
|
||||
from keras.callbacks import TensorBoard
|
||||
|
||||
class SimpleCNN(object):
|
||||
def __init__(self):
|
||||
self.image_size = 512
|
||||
self.classes = 6
|
||||
self.squares = 8
|
||||
self.boxes = 2
|
||||
self.build_model()
|
||||
|
||||
def iou(self, box_one, box_two):
|
||||
def op_func(combine, compute):
|
||||
return combine(compute(box_one), compute(box_two))
|
||||
|
||||
i_x1 = op_func(maximum, lambda b: b[0]-b[2]/2)
|
||||
i_y1 = op_func(maximum, lambda b: b[1]-b[3]/2)
|
||||
i_x2 = op_func(minimum, lambda b: b[0]+b[2]/2)
|
||||
i_y2 = op_func(minimum, lambda b: b[1]+b[3]/2)
|
||||
|
||||
area_1 = box_one[2]*box_one[3]
|
||||
area_2 = box_two[2]*box_one[3]
|
||||
intersection_area = (i_x2-i_x1)*(i_y2-i_y1)
|
||||
|
||||
return intersection_area/(area_1+area_2-intersection_area+.01)
|
||||
|
||||
def cost(self, truth_tensor, output_tensor):
|
||||
def input_output_tensor(f):
|
||||
def per_output_tensor(output):
|
||||
return tf.convert_to_tensor(
|
||||
[tf.map_fn(
|
||||
lambda truth: f(output[i*5:i*5+5],truth[0:5]),
|
||||
truth_tensor)
|
||||
for i in range(2)])
|
||||
|
||||
return tf.map_fn(per_output_tensor, output_tensor)
|
||||
|
||||
# Compute per object IOU values for each square, for each box.
|
||||
ious = input_output_tensor(self.iou)
|
||||
|
||||
# Compute the minimum IOS per object.
|
||||
min_class_ious = min(min(ious, axis=0), axis=0)
|
||||
|
||||
# Compute the minimum IOS per object.
|
||||
max_class_ious = max(max(ious, axis=0), axis=0)
|
||||
|
||||
# Whether each box of each square is responsible for
|
||||
# the minimum IOU. This is used for penalizing object absense.
|
||||
eq_min_box = tf.map_fn(lambda iou:
|
||||
tf.convert_to_tensor([equal(iou[j], min_class_ious) for j in range(2)]), ious, dtype='bool')
|
||||
# Same as above, but per-square rather than per-box.
|
||||
eq_min_square= any(eq_min_box, axis=1)
|
||||
|
||||
# Whether each box of each square is responsible
|
||||
# for the maximum IOU. This is used for penalizing
|
||||
# incorrect bounds and confidence.
|
||||
eq_max_box = tf.map_fn(lambda iou:
|
||||
tf.convert_to_tensor([equal(iou[j], max_class_ious) for j in range(2)]), ious, dtype='bool')
|
||||
# Same as above, but per-square. Penalizes bad class guesses.
|
||||
eq_max_square= any(eq_max_box, axis=1)
|
||||
|
||||
# The cost of incorrect coordinate guesses per box.
|
||||
coord_cost = input_output_tensor(
|
||||
lambda o,t: pow(o[0]-t[0], 2) + pow(o[1]-t[1], 2))
|
||||
# The cost of incorrect size guesses per box.
|
||||
dim_cost = input_output_tensor(
|
||||
lambda o,t: pow(pow(o[0], 0.5)-pow(t[0], 0.5), 2) + pow(pow(o[1], 0.5)-pow(t[1], 0.5), 2))
|
||||
# The cost of incorrect confidence guesses per box.
|
||||
confidence_cost = input_output_tensor(
|
||||
lambda o,t: pow(o[4]-t[4], 2))
|
||||
# The cost of incorrect class guesses, per square.
|
||||
class_cost = tf.map_fn(lambda output:
|
||||
tf.map_fn(lambda truth:
|
||||
tf.norm(output[2*5:2*5+6]-truth[5:12]), truth_tensor), output_tensor)
|
||||
|
||||
# Weights from the YOLO paper.
|
||||
coord_weight = 5
|
||||
obj_weight = 1
|
||||
noobj_weight = 0.5
|
||||
|
||||
# Cost, per box, penalized when an object is guessed.
|
||||
obj_cost= coord_weight * (coord_cost + dim_cost) + obj_weight * confidence_cost
|
||||
# Cost, per box, penalized when an object is not guessed.
|
||||
noobj_cost= noobj_weight * confidence_cost
|
||||
# Cost per box, selecting only "responsible" entries.
|
||||
box_cost = (
|
||||
obj_cost* cast(eq_max_box, 'float32') +
|
||||
noobj_cost* cast(eq_min_box, 'float32')
|
||||
)
|
||||
|
||||
# Cost per square, penalizing only "responsible" squares.
|
||||
square_cost= class_cost * cast(eq_max_square, 'float32')
|
||||
|
||||
# Total cost
|
||||
cost = sum(sum(sum(box_cost))) + sum(sum(square_cost))
|
||||
return cost
|
||||
|
||||
def build_model(self):
|
||||
self.model = Sequential()
|
||||
self.convolutional_layer(
|
||||
filters=16, kernel_size=7, strides=2, padding='same',
|
||||
input_shape=(self.image_size,self.image_size,3)
|
||||
)
|
||||
self.maxpool_layer(pool_size=2, strides=2, padding='same')
|
||||
|
||||
for _ in range(4):
|
||||
self.convolutional_layer(filters=8, kernel_size=3, padding='same')
|
||||
self.maxpool_layer(pool_size=2, strides=2, padding='same')
|
||||
|
||||
self.model.add(Flatten())
|
||||
self.model.add(Dense(units=self.squares*self.squares*(self.boxes*5+self.classes)))
|
||||
self.model.add(Reshape((64,-1)))
|
||||
self.model.compile(loss=self.loss, optimizer=Adam(lr=0.5e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0))
|
||||
|
||||
def maxpool_layer(self, *args, **kwargs):
|
||||
self.model.add(MaxPooling2D(*args, **kwargs))
|
||||
|
||||
def convolutional_layer(self, *args, **kwargs):
|
||||
self.model.add(Conv2D(*args, **kwargs))
|
||||
self.model.add(LeakyReLU())
|
||||
|
||||
def summary(self):
|
||||
self.model.summary()
|
||||
|
||||
def image_generator(source, n):
|
||||
for i in range(n):
|
||||
image_path = os.path.join(source, 'image' + str(i) + '.png')
|
||||
truth_path = os.path.join(source, 'objects' + str(i) + '.csv')
|
||||
data = img_to_array(load_img(image_path))
|
||||
truth_file = open(truth_path, "r")
|
||||
csv_reader = csv.reader(truth_file, delimiter=',')
|
||||
truth = tf.convert_to_tensor([[float(x) for x in row] for row in csv_reader])
|
||||
truth_file.close
|
||||
yield data, truth
|
||||
|
||||
inputs = []
|
||||
outputs = []
|
||||
for (inp, out) in image_generator('data/test', 100):
|
||||
inputs.append(inp)
|
||||
outputs.append(out)
|
||||
|
||||
log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
|
||||
|
||||
model = SimpleCNN()
|
||||
model.summary()
|
||||
model.model.fit(x=tf.convert_to_tensor(inputs), y=tf.convert_to_tensor(outputs), epochs=1, steps_per_epoch=1)
|
Loading…
Reference in New Issue
Block a user