Skip to content

Commit

Permalink
Remove create_serving_input_fn and export_strategies from lab to matc…
Browse files Browse the repository at this point in the history
…h the solution

Fixes GoogleCloudPlatform#2441
  • Loading branch information
MrCsabaToth committed Sep 13, 2023
1 parent 8a72c6a commit a36bec2
Showing 1 changed file with 1 addition and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,45 +132,6 @@ def batch_predict(args):
for best_items_for_user in topk.eval():
f.write(",".join(str(x) for x in best_items_for_user) + '\n')

# Online prediction returns row and column factors as needed
def create_serving_input_fn(args):
def for_user_embeddings(userId):
# All items for this user (for user_embeddings)
items = tf.range(args["nitems"], dtype = tf.int64)
users = userId * tf.ones(shape = [args["nitems"]], dtype = tf.int64)
ratings = 0.1 * tf.ones_like(tensor = users, dtype = tf.float32)
return items, users, ratings, tf.constant(value = True, dtype = tf.bool)

def for_item_embeddings(itemId):
# All users for this item (for item_embeddings)
users = tf.range(args["nusers"], dtype = tf.int64)
items = itemId * tf.ones(shape = [args["nusers"]], dtype = tf.int64)
ratings = 0.1 * tf.ones_like(tensor = users, dtype = tf.float32)
return items, users, ratings, tf.constant(value = False, dtype = tf.bool)

def serving_input_fn():
feature_ph = {
"userId": tf.placeholder(dtype = tf.int64, shape = 1),
"itemId": tf.placeholder(dtype = tf.int64, shape = 1)
}

(items, users, ratings, project_row) = \
tf.cond(pred = feature_ph["userId"][0] < tf.constant(value = 0, dtype = tf.int64),
true_fn = lambda: for_item_embeddings(feature_ph["itemId"]),
false_fn = lambda: for_user_embeddings(feature_ph["userId"]))
rows = tf.stack(values = [users, items], axis = 1)
cols = tf.stack(values = [items, users], axis = 1)
input_rows = tf.SparseTensor(indices = rows, values = ratings, dense_shape = (args["nusers"], args["nitems"]))
input_cols = tf.SparseTensor(indices = cols, values = ratings, dense_shape = (args["nusers"], args["nitems"]))

features = {
WALSMatrixFactorization.INPUT_ROWS: input_rows,
WALSMatrixFactorization.INPUT_COLS: input_cols,
WALSMatrixFactorization.PROJECT_ROW: project_row
}
return tf.contrib.learn.InputFnOps(features = features, labels = None, default_inputs = feature_ph)
return serving_input_fn

def train_and_evaluate(args):
train_steps = int(0.5 + (1.0 * args["num_epochs"] * args["nusers"]) / args["batch_size"])
steps_in_epoch = int(0.5 + args["nusers"] / args["batch_size"])
Expand All @@ -186,8 +147,7 @@ def experiment_fn(output_dir):
eval_input_fn = read_dataset(tf.estimator.ModeKeys.EVAL, args),
train_steps = train_steps,
eval_steps = 1,
min_eval_frequency = steps_in_epoch,
export_strategies = tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(serving_input_fn = create_serving_input_fn(args))
min_eval_frequency = steps_in_epoch
)

from tensorflow.contrib.learn.python.learn import learn_runner
Expand Down

0 comments on commit a36bec2

Please sign in to comment.