{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "import array\n",
    "import gzip\n",
    "import random\n",
    "from collections import defaultdict\n",
    "\n",
    "import dateutil\n",
    "import dateutil.parser  # explicit: `import dateutil` alone may not load the parser submodule\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse(path):\n",
    "    g = gzip.open(path, 'r')\n",
    "    for l in g:\n",
    "        yield eval(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compact integer index for each user / item, assigned in order of first appearance\n",
    "userIDs = {}\n",
    "itemIDs = {}\n",
    "# Flat list of (timestamp, user, item, rating) tuples\n",
    "interactions = []\n",
    "# user -> list of (timestamp, item, rating) tuples\n",
    "interactionsPerUser = defaultdict(list)\n",
    "\n",
    "for d in parse(\"goodreads_reviews_comics_graphic.json.gz\"):\n",
    "    u = d['user_id']\n",
    "    i = d['book_id']\n",
    "    r = d['rating']\n",
    "    # 'date_added' is a date string; convert it to a whole-second POSIX timestamp.\n",
    "    # Uses dateutil.parser, which the import cell loads explicitly.\n",
    "    dt = dateutil.parser.parse(d['date_added'])\n",
    "    t = int(dt.timestamp())\n",
    "    if u not in userIDs:\n",
    "        userIDs[u] = len(userIDs)\n",
    "    if i not in itemIDs:\n",
    "        itemIDs[i] = len(itemIDs)\n",
    "    interactions.append((t, u, i, r))\n",
    "    interactionsPerUser[u].append((t, i, r))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Wed Apr 03 10:10:41 -0700 2013'"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the raw 'date_added' string of the record currently held in `d`\n",
    "# (`d` is the loop variable of the loading cell).\n",
    "d['date_added']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the record's date string with dateutil so the result can be inspected\n",
    "dt = dateutil.parser.parse(d['date_added'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on datetime object:\n",
      "\n",
      "class datetime(date)\n",
      " |  datetime(year, month, day[, hour[, minute[, second[, microsecond[,tzinfo]]]]])\n",
      " |  \n",
      " |  The year, month and day arguments are required. tzinfo may be None, or an\n",
      " |  instance of a tzinfo subclass. The remaining arguments may be ints.\n",
      " |  \n",
      " |  Method resolution order:\n",
      " |      datetime\n",
      " |      date\n",
      " |      builtins.object\n",
      " |  \n",
      " |  Methods defined here:\n",
      " |  \n",
      " |  __add__(self, value, /)\n",
      " |      Return self+value.\n",
      " |  \n",
      " |  __eq__(self, value, /)\n",
      " |      Return self==value.\n",
      " |  \n",
      " |  __ge__(self, value, /)\n",
      " |      Return self>=value.\n",
      " |  \n",
      " |  __getattribute__(self, name, /)\n",
      " |      Return getattr(self, name).\n",
      " |  \n",
      " |  __gt__(self, value, /)\n",
      " |      Return self>value.\n",
      " |  \n",
      " |  __hash__(self, /)\n",
      " |      Return hash(self).\n",
      " |  \n",
      " |  __le__(self, value, /)\n",
      " |      Return self<=value.\n",
      " |  \n",
      " |  __lt__(self, value, /)\n",
      " |      Return self<value.\n",
      " |  \n",
      " |  __ne__(self, value, /)\n",
      " |      Return self!=value.\n",
      " |  \n",
      " |  __new__(*args, **kwargs) from builtins.type\n",
      " |      Create and return a new object.  See help(type) for accurate signature.\n",
      " |  \n",
      " |  __radd__(self, value, /)\n",
      " |      Return value+self.\n",
      " |  \n",
      " |  __reduce__(...)\n",
      " |      __reduce__() -> (cls, state)\n",
      " |  \n",
      " |  __repr__(self, /)\n",
      " |      Return repr(self).\n",
      " |  \n",
      " |  __rsub__(self, value, /)\n",
      " |      Return value-self.\n",
      " |  \n",
      " |  __str__(self, /)\n",
      " |      Return str(self).\n",
      " |  \n",
      " |  __sub__(self, value, /)\n",
      " |      Return self-value.\n",
      " |  \n",
      " |  astimezone(...)\n",
      " |      tz -> convert to local time in new timezone tz\n",
      " |  \n",
      " |  combine(...) from builtins.type\n",
      " |      date, time -> datetime with same date and time fields\n",
      " |  \n",
      " |  ctime(...)\n",
      " |      Return ctime() style string.\n",
      " |  \n",
      " |  date(...)\n",
      " |      Return date object with same year, month and day.\n",
      " |  \n",
      " |  dst(...)\n",
      " |      Return self.tzinfo.dst(self).\n",
      " |  \n",
      " |  fromtimestamp(...) from builtins.type\n",
      " |      timestamp[, tz] -> tz's local time from POSIX timestamp.\n",
      " |  \n",
      " |  isoformat(...)\n",
      " |      [sep] -> string in ISO 8601 format, YYYY-MM-DDTHH:MM:SS[.mmmmmm][+HH:MM].\n",
      " |      \n",
      " |      sep is used to separate the year from the time, and defaults to 'T'.\n",
      " |  \n",
      " |  now(tz=None) from builtins.type\n",
      " |      Returns new datetime object representing current time local to tz.\n",
      " |      \n",
      " |        tz\n",
      " |          Timezone object.\n",
      " |      \n",
      " |      If no tz is specified, uses local timezone.\n",
      " |  \n",
      " |  replace(...)\n",
      " |      Return datetime with new specified fields.\n",
      " |  \n",
      " |  strptime(...) from builtins.type\n",
      " |      string, format -> new datetime parsed from a string (like time.strptime()).\n",
      " |  \n",
      " |  time(...)\n",
      " |      Return time object with same time but with tzinfo=None.\n",
      " |  \n",
      " |  timestamp(...)\n",
      " |      Return POSIX timestamp as float.\n",
      " |  \n",
      " |  timetuple(...)\n",
      " |      Return time tuple, compatible with time.localtime().\n",
      " |  \n",
      " |  timetz(...)\n",
      " |      Return time object with same time and tzinfo.\n",
      " |  \n",
      " |  tzname(...)\n",
      " |      Return self.tzinfo.tzname(self).\n",
      " |  \n",
      " |  utcfromtimestamp(...) from builtins.type\n",
      " |      Construct a naive UTC datetime from a POSIX timestamp.\n",
      " |  \n",
      " |  utcnow(...) from builtins.type\n",
      " |      Return a new datetime representing UTC day and time.\n",
      " |  \n",
      " |  utcoffset(...)\n",
      " |      Return self.tzinfo.utcoffset(self).\n",
      " |  \n",
      " |  utctimetuple(...)\n",
      " |      Return UTC time tuple, compatible with time.localtime().\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Data descriptors defined here:\n",
      " |  \n",
      " |  hour\n",
      " |  \n",
      " |  microsecond\n",
      " |  \n",
      " |  minute\n",
      " |  \n",
      " |  second\n",
      " |  \n",
      " |  tzinfo\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Data and other attributes defined here:\n",
      " |  \n",
      " |  max = datetime.datetime(9999, 12, 31, 23, 59, 59, 999999)\n",
      " |  \n",
      " |  min = datetime.datetime(1, 1, 1, 0, 0)\n",
      " |  \n",
      " |  resolution = datetime.timedelta(0, 0, 1)\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Methods inherited from date:\n",
      " |  \n",
      " |  __format__(...)\n",
      " |      Formats self with strftime.\n",
      " |  \n",
      " |  fromordinal(...) from builtins.type\n",
      " |      int -> date corresponding to a proleptic Gregorian ordinal.\n",
      " |  \n",
      " |  isocalendar(...)\n",
      " |      Return a 3-tuple containing ISO year, week number, and weekday.\n",
      " |  \n",
      " |  isoweekday(...)\n",
      " |      Return the day of the week represented by the date.\n",
      " |      Monday == 1 ... Sunday == 7\n",
      " |  \n",
      " |  strftime(...)\n",
      " |      format -> strftime() style string.\n",
      " |  \n",
      " |  today(...) from builtins.type\n",
      " |      Current date or datetime:  same as self.__class__.fromtimestamp(time.time()).\n",
      " |  \n",
      " |  toordinal(...)\n",
      " |      Return proleptic Gregorian ordinal.  January 1 of year 1 is day 1.\n",
      " |  \n",
      " |  weekday(...)\n",
      " |      Return the day of the week represented by the date.\n",
      " |      Monday == 0 ... Sunday == 6\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Data descriptors inherited from date:\n",
      " |  \n",
      " |  day\n",
      " |  \n",
      " |  month\n",
      " |  \n",
      " |  year\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Interactive exploration: print the built-in help for the parsed object.\n",
    "# NOTE(review): produces a very long output; consider clearing it before sharing.\n",
    "help(dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1365009041.0"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# POSIX timestamp of the parsed date (a float)\n",
    "dt.timestamp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "542338"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of interaction tuples collected by the loading cell\n",
    "len(interactions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-place sort; tuples are (t, u, i, r), so the timestamp is the primary key.\n",
    "# Be careful building train/test splits after sorting!\n",
    "interactions.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reserve an index for the placeholder item 'dummy'.\n",
    "# setdefault keeps this cell idempotent: with a plain assignment, re-running\n",
    "# the cell would move 'dummy' to index len(itemIDs), one past the last row of\n",
    "# any table sized len(itemIDs).\n",
    "itemIDs.setdefault('dummy', len(itemIDs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build (timestamp, user, item, previous item, rating) tuples, one per interaction.\n",
    "interactionsWithPrevious = []\n",
    "\n",
    "for u, history in interactionsPerUser.items():\n",
    "    # Order this user's (t, i, r) tuples in place, timestamp first\n",
    "    history.sort()\n",
    "    # The first interaction of every user has no predecessor: use the placeholder\n",
    "    lastItem = 'dummy'\n",
    "    for t, i, r in history:\n",
    "        interactionsWithPrevious.append((t, u, i, lastItem, r))\n",
    "        lastItem = i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# user -> set of every item that user has interacted with\n",
    "itemsPerUser = defaultdict(set)\n",
    "for _t, u, i, _r in interactions:\n",
    "    itemsPerUser[u].add(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List of all item keys, in insertion order.\n",
    "# NOTE(review): this includes the 'dummy' placeholder if it has already been\n",
    "# added to itemIDs -- confirm that is intended for negative sampling.\n",
    "items = list(itemIDs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam optimizer shared by every training step\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FPMC(tf.keras.Model):\n",
    "    \"\"\"Sequential ranking model scoring an item from the user and the previous item.\n",
    "\n",
    "    The score of item i for user u, given previous item j, is\n",
    "        betaI[i] + dot(gammaUI[u], gammaIU[i]) + dot(gammaIJ[i], gammaJI[j])\n",
    "    Table sizes come from the module-level `userIDs` and `itemIDs` dicts, so\n",
    "    both must be fully populated before the model is constructed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, K, lamb):\n",
    "        \"\"\"Create the parameter tables.\n",
    "\n",
    "        K: number of latent dimensions per embedding.\n",
    "        lamb: regularization coefficient used by reg().\n",
    "        \"\"\"\n",
    "        super(FPMC, self).__init__()\n",
    "        nUsers = len(userIDs)\n",
    "        nItems = len(itemIDs)\n",
    "        # Per-item bias\n",
    "        self.betaI = tf.Variable(tf.random.normal([nItems], stddev=0.001))\n",
    "        # User/item compatibility factors\n",
    "        self.gammaUI = tf.Variable(tf.random.normal([nUsers, K], stddev=0.001))\n",
    "        self.gammaIU = tf.Variable(tf.random.normal([nItems, K], stddev=0.001))\n",
    "        # Item/previous-item compatibility factors\n",
    "        self.gammaIJ = tf.Variable(tf.random.normal([nItems, K], stddev=0.001))\n",
    "        self.gammaJI = tf.Variable(tf.random.normal([nItems, K], stddev=0.001))\n",
    "        # Regularization coefficient\n",
    "        self.lamb = lamb\n",
    "\n",
    "    def predict(self, u, i, j):\n",
    "        \"\"\"Score a single (user index, item index, previous item index) triple.\n",
    "\n",
    "        NOTE: this overrides tf.keras.Model.predict; the name and signature\n",
    "        are kept so existing callers continue to work.\n",
    "        \"\"\"\n",
    "        userTerm = tf.tensordot(self.gammaUI[u], self.gammaIU[i], 1)\n",
    "        sequenceTerm = tf.tensordot(self.gammaIJ[i], self.gammaJI[j], 1)\n",
    "        return self.betaI[i] + userTerm + sequenceTerm\n",
    "\n",
    "    def reg(self):\n",
    "        \"\"\"Return lamb times the summed L2 penalty of every parameter table.\n",
    "\n",
    "        The sum is parenthesized so the coefficient scales all five terms;\n",
    "        without the parentheses only the betaI term would be multiplied by\n",
    "        lamb and the four gamma tables would be penalized with weight 1.\n",
    "        \"\"\"\n",
    "        return self.lamb * (tf.nn.l2_loss(self.betaI) +\n",
    "                            tf.nn.l2_loss(self.gammaUI) +\n",
    "                            tf.nn.l2_loss(self.gammaIU) +\n",
    "                            tf.nn.l2_loss(self.gammaIJ) +\n",
    "                            tf.nn.l2_loss(self.gammaJI))\n",
    "\n",
    "    def _batchScore(self, u, i, j):\n",
    "        \"\"\"Score a batch of (user, item, previous item) index tensors element-wise.\"\"\"\n",
    "        gamma_ui = tf.nn.embedding_lookup(self.gammaUI, u)\n",
    "        gamma_iu = tf.nn.embedding_lookup(self.gammaIU, i)\n",
    "        gamma_ij = tf.nn.embedding_lookup(self.gammaIJ, i)\n",
    "        gamma_ji = tf.nn.embedding_lookup(self.gammaJI, j)\n",
    "        beta_i = tf.nn.embedding_lookup(self.betaI, i)\n",
    "        return (beta_i +\n",
    "                tf.reduce_sum(tf.multiply(gamma_ui, gamma_iu), 1) +\n",
    "                tf.reduce_sum(tf.multiply(gamma_ij, gamma_ji), 1))\n",
    "\n",
    "    def call(self, sampleU, # user\n",
    "                   sampleI, # item\n",
    "                   sampleJ, # previous item\n",
    "                   sampleK): # negative item\n",
    "        \"\"\"Return the mean pairwise ranking loss over a batch of index sequences.\n",
    "\n",
    "        For each position, the observed item (sampleI) is compared against the\n",
    "        negative item (sampleK) for the same user and previous item; the loss\n",
    "        is -mean(log(sigmoid(score_observed - score_negative))).\n",
    "        \"\"\"\n",
    "        u = tf.convert_to_tensor(sampleU, dtype=tf.int32)\n",
    "        i = tf.convert_to_tensor(sampleI, dtype=tf.int32)\n",
    "        j = tf.convert_to_tensor(sampleJ, dtype=tf.int32)\n",
    "        k = tf.convert_to_tensor(sampleK, dtype=tf.int32)\n",
    "        x_uij = self._batchScore(u, i, j)\n",
    "        x_ukj = self._batchScore(u, k, j)\n",
    "        # log_sigmoid avoids log(0) = -inf when the score difference is very negative\n",
    "        return -tf.reduce_mean(tf.math.log_sigmoid(x_uij - x_ukj))\n",
    "\n",
    "model = FPMC(5, 0.00001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainingStep(interactions, nSamples=100000):\n",
    "    \"\"\"Run one optimizer update on a freshly sampled batch and return its loss.\n",
    "\n",
    "    interactions: accepted for interface compatibility but not read; positive\n",
    "        samples are drawn from the module-level `interactionsWithPrevious`.\n",
    "        NOTE(review): confirm whether callers expect sampling from this argument.\n",
    "    nSamples: number of (user, item, previous item, negative item) samples per\n",
    "        step. Defaults to 100000, the value that was previously hard-coded.\n",
    "\n",
    "    Returns the batch loss plus the regularizer, as a NumPy value.\n",
    "    Relies on the module-level `model`, `optimizer`, `items`, `itemsPerUser`,\n",
    "    `userIDs` and `itemIDs`.\n",
    "    \"\"\"\n",
    "    # Sampling is plain Python, so it is done outside the gradient tape.\n",
    "    sampleU, sampleI, sampleJ, sampleK = [], [], [], []\n",
    "    for _ in range(nSamples):\n",
    "        _,u,i,j,_ = random.choice(interactionsWithPrevious) # positive sample\n",
    "        k = random.choice(items) # negative sample\n",
    "        # Redraw until the negative is an item this user has not interacted with\n",
    "        while k in itemsPerUser[u]:\n",
    "            k = random.choice(items)\n",
    "        sampleU.append(userIDs[u])\n",
    "        sampleI.append(itemIDs[i])\n",
    "        sampleJ.append(itemIDs[j])\n",
    "        sampleK.append(itemIDs[k])\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss = model(sampleU,sampleI,sampleJ,sampleK)\n",
    "        loss += model.reg()\n",
    "    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "    # Skip variables that received no gradient in this step\n",
    "    optimizer.apply_gradients([(grad, var) for\n",
    "                               (grad, var) in zip(gradients, model.trainable_variables)\n",
    "                               if grad is not None])\n",
    "    return loss.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration 0, objective = 1.511279582977295\n",
      "iteration 1, objective = 1.2588390707969666\n",
      "iteration 2, objective = 1.1297369996706645\n",
      "iteration 3, objective = 1.083426907658577\n",
      "iteration 4, objective = 1.0654948353767395\n",
      "iteration 5, objective = 1.0426013966401417\n",
      "iteration 6, objective = 1.0136300751141138\n",
      "iteration 7, objective = 0.9877421036362648\n",
      "iteration 8, objective = 0.9688891900910271\n",
      "iteration 9, objective = 0.9537673056125641\n",
      "iteration 10, objective = 0.9389547543092207\n",
      "iteration 11, objective = 0.924468070268631\n",
      "iteration 12, objective = 0.9119442242842454\n",
      "iteration 13, objective = 0.9017925603049142\n",
      "iteration 14, objective = 0.8928021550178528\n",
      "iteration 15, objective = 0.8838572315871716\n",
      "iteration 16, objective = 0.8749939413631663\n",
      "iteration 17, objective = 0.8668547239568498\n",
      "iteration 18, objective = 0.8596962125677812\n",
      "iteration 19, objective = 0.8531941503286362\n",
      "iteration 20, objective = 0.8469671947615487\n",
      "iteration 21, objective = 0.8409658778797496\n",
      "iteration 22, objective = 0.8353376103484113\n",
      "iteration 23, objective = 0.8301393563548723\n",
      "iteration 24, objective = 0.8252876377105713\n",
      "iteration 25, objective = 0.8206844834180979\n",
      "iteration 26, objective = 0.8163018160396152\n",
      "iteration 27, objective = 0.8121438920497894\n",
      "iteration 28, objective = 0.808196719350486\n",
      "iteration 29, objective = 0.8044517914454142\n",
      "iteration 30, objective = 0.8009111785119579\n",
      "iteration 31, objective = 0.7975556664168835\n",
      "iteration 32, objective = 0.7943431048682241\n",
      "iteration 33, objective = 0.7912541680476245\n",
      "iteration 34, objective = 0.7883054665156773\n",
      "iteration 35, objective = 0.7855102452966902\n",
      "iteration 36, objective = 0.7828441987166533\n",
      "iteration 37, objective = 0.7802768381018388\n",
      "iteration 38, objective = 0.7777983790788895\n",
      "iteration 39, objective = 0.7754291370511055\n",
      "iteration 40, objective = 0.7731650108244361\n",
      "iteration 41, objective = 0.7709876270521254\n",
      "iteration 42, objective = 0.7688826294832452\n",
      "iteration 43, objective = 0.766856079751795\n",
      "iteration 44, objective = 0.7649076700210571\n",
      "iteration 45, objective = 0.7630311548709869\n",
      "iteration 46, objective = 0.7612194609134755\n",
      "iteration 47, objective = 0.7594666803876559\n",
      "iteration 48, objective = 0.7577734066515552\n",
      "iteration 49, objective = 0.7561364960670471\n",
      "iteration 50, objective = 0.7545543885698506\n",
      "iteration 51, objective = 0.7530237344595102\n",
      "iteration 52, objective = 0.7515405641411835\n",
      "iteration 53, objective = 0.7500993178950416\n",
      "iteration 54, objective = 0.7487024188041687\n",
      "iteration 55, objective = 0.7473486268094608\n",
      "iteration 56, objective = 0.74603404601415\n",
      "iteration 57, objective = 0.7447568270666846\n",
      "iteration 58, objective = 0.7435138235657902\n",
      "iteration 59, objective = 0.7423051019509633\n",
      "iteration 60, objective = 0.7411307174651349\n",
      "iteration 61, objective = 0.7399865581143287\n",
      "iteration 62, objective = 0.738871072965955\n",
      "iteration 63, objective = 0.737782571464777\n",
      "iteration 64, objective = 0.7367215339954083\n",
      "iteration 65, objective = 0.7356885843204729\n",
      "iteration 66, objective = 0.7346795423706965\n",
      "iteration 67, objective = 0.7336928073097678\n",
      "iteration 68, objective = 0.7327290160068567\n",
      "iteration 69, objective = 0.7317889622279576\n",
      "iteration 70, objective = 0.7308707564649447\n",
      "iteration 71, objective = 0.7299708525339762\n",
      "iteration 72, objective = 0.7290887481545749\n",
      "iteration 73, objective = 0.7282281139412442\n",
      "iteration 74, objective = 0.7273833179473876\n",
      "iteration 75, objective = 0.7265568673610687\n",
      "iteration 76, objective = 0.7257468832003606\n",
      "iteration 77, objective = 0.7249520092438428\n",
      "iteration 78, objective = 0.724172756641726\n",
      "iteration 79, objective = 0.7234086379408836\n",
      "iteration 80, objective = 0.7226598961853686\n",
      "iteration 81, objective = 0.7219269515537634\n",
      "iteration 82, objective = 0.721204262181937\n",
      "iteration 83, objective = 0.7204953460466295\n",
      "iteration 84, objective = 0.7197979148696451\n",
      "iteration 85, objective = 0.7191143444804258\n",
      "iteration 86, objective = 0.718441216424964\n",
      "iteration 87, objective = 0.7177786888046698\n",
      "iteration 88, objective = 0.7171292304992676\n",
      "iteration 89, objective = 0.7164888428317175\n",
      "iteration 90, objective = 0.7158573584242182\n",
      "iteration 91, objective = 0.7152379433745923\n",
      "iteration 92, objective = 0.7146299423709992\n",
      "iteration 93, objective = 0.7140273623009945\n",
      "iteration 94, objective = 0.7134352527166667\n",
      "iteration 95, objective = 0.7128530777990818\n",
      "iteration 96, objective = 0.7122775512872282\n",
      "iteration 97, objective = 0.7117099384872281\n",
      "iteration 98, objective = 0.7111516022923017\n",
      "iteration 99, objective = 0.7106005704402923\n",
      "iteration 100, objective = 0.7100562264423559\n",
      "iteration 101, objective = 0.7095218312506583\n",
      "iteration 102, objective = 0.7089926154868117\n",
      "iteration 103, objective = 0.7084703869544543\n",
      "iteration 104, objective = 0.7079550413858323\n",
      "iteration 105, objective = 0.7074466342071317\n",
      "iteration 106, objective = 0.7069433504175917\n",
      "iteration 107, objective = 0.7064493022583149\n",
      "iteration 108, objective = 0.7059602535099064\n",
      "iteration 109, objective = 0.7054755774411288\n",
      "iteration 110, objective = 0.7049960359796748\n",
      "iteration 111, objective = 0.7045255745095866\n",
      "iteration 112, objective = 0.7040586413535397\n",
      "iteration 113, objective = 0.7035957538245017\n",
      "iteration 114, objective = 0.7031366690345432\n",
      "iteration 115, objective = 0.702685164994207\n",
      "iteration 116, objective = 0.7022385225336776\n",
      "iteration 117, objective = 0.70179562598972\n",
      "iteration 118, objective = 0.7013574177477541\n",
      "iteration 119, objective = 0.700923993686835\n",
      "iteration 120, objective = 0.7004959647320519\n",
      "iteration 121, objective = 0.7000723280867592\n",
      "iteration 122, objective = 0.6996519483201872\n",
      "iteration 123, objective = 0.6992367203197172\n",
      "iteration 124, objective = 0.6988259835243225\n",
      "iteration 125, objective = 0.6984184086322784\n",
      "iteration 126, objective = 0.6980149830420186\n",
      "iteration 127, objective = 0.6976144746877253\n",
      "iteration 128, objective = 0.697217395139295\n",
      "iteration 129, objective = 0.6968239830090449\n",
      "iteration 130, objective = 0.6964337739325662\n",
      "iteration 131, objective = 0.696048870231166\n",
      "iteration 132, objective = 0.6956687273835778\n",
      "iteration 133, objective = 0.6952885780761491\n",
      "iteration 134, objective = 0.6949125709357085\n",
      "iteration 135, objective = 0.6945386883967063\n",
      "iteration 136, objective = 0.6941698994079646\n",
      "iteration 137, objective = 0.6938040865504224\n",
      "iteration 138, objective = 0.6934419415837569\n",
      "iteration 139, objective = 0.6930804955107825\n",
      "iteration 140, objective = 0.6927258566761694\n",
      "iteration 141, objective = 0.6923705221901477\n",
      "iteration 142, objective = 0.6920184622277746\n",
      "iteration 143, objective = 0.6916693349679311\n",
      "iteration 144, objective = 0.6913212266461602\n",
      "iteration 145, objective = 0.6909782911000186\n",
      "iteration 146, objective = 0.6906380300619164\n",
      "iteration 147, objective = 0.6902991138600014\n",
      "iteration 148, objective = 0.6899628427204669\n",
      "iteration 149, objective = 0.6896285080909729\n",
      "iteration 150, objective = 0.6892968755684151\n",
      "iteration 151, objective = 0.6889681804337\n",
      "iteration 152, objective = 0.6886415481567383\n",
      "iteration 153, objective = 0.6883151120953745\n",
      "iteration 154, objective = 0.6879938325574321\n",
      "iteration 155, objective = 0.6876733161700077\n",
      "iteration 156, objective = 0.6873554200123829\n",
      "iteration 157, objective = 0.6870390950124475\n",
      "iteration 158, objective = 0.6867239246578336\n",
      "iteration 159, objective = 0.686412651464343\n",
      "iteration 160, objective = 0.686103936307919\n",
      "iteration 161, objective = 0.6857951548364427\n",
      "iteration 162, objective = 0.685490209266452\n",
      "iteration 163, objective = 0.6851867680869451\n",
      "iteration 164, objective = 0.6848841009717999\n",
      "iteration 165, objective = 0.6845839644771025\n",
      "iteration 166, objective = 0.6842864323518947\n",
      "iteration 167, objective = 0.6839908680745533\n",
      "iteration 168, objective = 0.6836962125005102\n",
      "iteration 169, objective = 0.6834033555844251\n",
      "iteration 170, objective = 0.683112138893172\n",
      "iteration 171, objective = 0.6828214259341706\n",
      "iteration 172, objective = 0.6825320713781897\n",
      "iteration 173, objective = 0.6822467001005151\n",
      "iteration 174, objective = 0.6819614968981061\n",
      "iteration 175, objective = 0.6816778792576357\n",
      "iteration 176, objective = 0.681395521945199\n",
      "iteration 177, objective = 0.6811166268386198\n",
      "iteration 178, objective = 0.6808395385742188\n",
      "iteration 179, objective = 0.6805620825952954\n",
      "iteration 180, objective = 0.6802850856306804\n",
      "iteration 181, objective = 0.6800109724422078\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration 182, objective = 0.6797387873540159\n",
      "iteration 183, objective = 0.6794705980497858\n",
      "iteration 184, objective = 0.6792025868957107\n",
      "iteration 185, objective = 0.6789336790961604\n",
      "iteration 186, objective = 0.6786676715402042\n",
      "iteration 187, objective = 0.6784027957535804\n",
      "iteration 188, objective = 0.6781373998475453\n",
      "iteration 189, objective = 0.6778746128082276\n",
      "iteration 190, objective = 0.6776137870019643\n",
      "iteration 191, objective = 0.6773542103668054\n",
      "iteration 192, objective = 0.6770944255621322\n",
      "iteration 193, objective = 0.676837068857606\n",
      "iteration 194, objective = 0.676579177685273\n",
      "iteration 195, objective = 0.6763236592618787\n",
      "iteration 196, objective = 0.676069193987677\n",
      "iteration 197, objective = 0.6758159570621721\n",
      "iteration 198, objective = 0.6755619684056421\n",
      "iteration 199, objective = 0.6753104063868522\n",
      "iteration 200, objective = 0.675062860719007\n",
      "iteration 201, objective = 0.6748144709237731\n",
      "iteration 202, objective = 0.6745672883658573\n",
      "iteration 203, objective = 0.6743210013590607\n",
      "iteration 204, objective = 0.6740748949167205\n",
      "iteration 205, objective = 0.6738299860537631\n",
      "iteration 206, objective = 0.673587819804316\n",
      "iteration 207, objective = 0.6733452918437811\n",
      "iteration 208, objective = 0.6731058316367665\n",
      "iteration 209, objective = 0.6728673565955389\n",
      "iteration 210, objective = 0.6726272439504687\n",
      "iteration 211, objective = 0.6723903471006537\n",
      "iteration 212, objective = 0.6721533914686928\n",
      "iteration 213, objective = 0.6719170151469863\n",
      "iteration 214, objective = 0.6716823852339456\n",
      "iteration 215, objective = 0.6714475897175295\n",
      "iteration 216, objective = 0.6712161851918093\n",
      "iteration 217, objective = 0.6709852442828887\n",
      "iteration 218, objective = 0.6707553128673606\n",
      "iteration 219, objective = 0.6705267602747137\n",
      "iteration 220, objective = 0.6702987122859351\n",
      "iteration 221, objective = 0.6700726511242153\n",
      "iteration 222, objective = 0.6698443483344108\n",
      "iteration 223, objective = 0.669617451195206\n",
      "iteration 224, objective = 0.6693912842538622\n",
      "iteration 225, objective = 0.6691676520668299\n",
      "iteration 226, objective = 0.6689425383895504\n",
      "iteration 227, objective = 0.6687195847431818\n",
      "iteration 228, objective = 0.6684980946857336\n",
      "iteration 229, objective = 0.6682744982449905\n",
      "iteration 230, objective = 0.6680524942678806\n",
      "iteration 231, objective = 0.6678340103605698\n",
      "iteration 232, objective = 0.6676140755031242\n",
      "iteration 233, objective = 0.6673944327566359\n",
      "iteration 234, objective = 0.667177718750974\n",
      "iteration 235, objective = 0.6669618083258807\n",
      "iteration 236, objective = 0.6667446999610225\n",
      "iteration 237, objective = 0.6665296344196096\n",
      "iteration 238, objective = 0.6663144033823053\n",
      "iteration 239, objective = 0.6661002916594346\n",
      "iteration 240, objective = 0.6658883589431953\n",
      "iteration 241, objective = 0.6656769016557489\n",
      "iteration 242, objective = 0.6654644240567713\n",
      "iteration 243, objective = 0.6652520296515011\n",
      "iteration 244, objective = 0.6650414980187708\n",
      "iteration 245, objective = 0.6648304474547626\n",
      "iteration 246, objective = 0.664622322991792\n",
      "iteration 247, objective = 0.6644147531159462\n",
      "iteration 248, objective = 0.6642073505374801\n",
      "iteration 249, objective = 0.6640016286373138\n",
      "iteration 250, objective = 0.6637969686690555\n",
      "iteration 251, objective = 0.6635922299964088\n",
      "iteration 252, objective = 0.6633879193675377\n",
      "iteration 253, objective = 0.6631853568741656\n",
      "iteration 254, objective = 0.6629817296476925\n",
      "iteration 255, objective = 0.6627791111823171\n",
      "iteration 256, objective = 0.6625758504589244\n",
      "iteration 257, objective = 0.6623751338600188\n",
      "iteration 258, objective = 0.6621732258428478\n",
      "iteration 259, objective = 0.6619721894080822\n",
      "iteration 260, objective = 0.6617724927449135\n",
      "iteration 261, objective = 0.6615732596575759\n",
      "iteration 262, objective = 0.6613734548082824\n",
      "iteration 263, objective = 0.6611772678566702\n",
      "iteration 264, objective = 0.6609817205734972\n",
      "iteration 265, objective = 0.6607846953815087\n",
      "iteration 266, objective = 0.6605878179885921\n",
      "iteration 267, objective = 0.6603919252086041\n",
      "iteration 268, objective = 0.6601960472901072\n",
      "iteration 269, objective = 0.6600032960927045\n",
      "iteration 270, objective = 0.6598108184293627\n",
      "iteration 271, objective = 0.659617926706286\n",
      "iteration 272, objective = 0.6594267568308791\n",
      "iteration 273, objective = 0.6592361315758559\n",
      "iteration 274, objective = 0.659043824672699\n",
      "iteration 275, objective = 0.6588541223951008\n",
      "iteration 276, objective = 0.6586609258978806\n",
      "iteration 277, objective = 0.6584714329499992\n",
      "iteration 278, objective = 0.6582812893347928\n",
      "iteration 279, objective = 0.6580922586577279\n",
      "iteration 280, objective = 0.6579055141299645\n",
      "iteration 281, objective = 0.6577175971886791\n",
      "iteration 282, objective = 0.6575313762304218\n",
      "iteration 283, objective = 0.6573448479175568\n",
      "iteration 284, objective = 0.6571604143109238\n",
      "iteration 285, objective = 0.6569741912238247\n",
      "iteration 286, objective = 0.6567891065667315\n",
      "iteration 287, objective = 0.6566054769274261\n",
      "iteration 288, objective = 0.6564221996749561\n",
      "iteration 289, objective = 0.6562392867844681\n",
      "iteration 290, objective = 0.6560570042567564\n",
      "iteration 291, objective = 0.6558766795756066\n",
      "iteration 292, objective = 0.6556946256868669\n",
      "iteration 293, objective = 0.6555116340822104\n",
      "iteration 294, objective = 0.6553311691445819\n",
      "iteration 295, objective = 0.6551494342652527\n",
      "iteration 296, objective = 0.6549679992174862\n",
      "iteration 297, objective = 0.654788577516607\n",
      "iteration 298, objective = 0.6546111726840603\n",
      "iteration 299, objective = 0.6544300723075867\n",
      "iteration 300, objective = 0.6542526529080844\n",
      "iteration 301, objective = 0.6540752327600062\n",
      "iteration 302, objective = 0.6538976865633093\n",
      "iteration 303, objective = 0.6537209749221802\n",
      "iteration 304, objective = 0.6535442973746628\n",
      "iteration 305, objective = 0.6533689249574749\n",
      "iteration 306, objective = 0.6531922510470164\n",
      "iteration 307, objective = 0.6530191215214791\n",
      "iteration 308, objective = 0.6528429485447584\n",
      "iteration 309, objective = 0.6526686410750112\n",
      "iteration 310, objective = 0.6524955131999932\n",
      "iteration 311, objective = 0.6523241437016389\n",
      "iteration 312, objective = 0.6521511095019575\n",
      "iteration 313, objective = 0.6519782681753681\n",
      "iteration 314, objective = 0.6518070868083409\n",
      "iteration 315, objective = 0.6516358046969281\n",
      "iteration 316, objective = 0.6514634764532937\n",
      "iteration 317, objective = 0.6512923459961729\n",
      "iteration 318, objective = 0.6511212589225052\n",
      "iteration 319, objective = 0.6509511608630418\n",
      "iteration 320, objective = 0.6507811965972092\n",
      "iteration 321, objective = 0.650613717595983\n",
      "iteration 322, objective = 0.6504443629238259\n",
      "iteration 323, objective = 0.6502758641669779\n",
      "iteration 324, objective = 0.6501080327767592\n",
      "iteration 325, objective = 0.649940320684866\n",
      "iteration 326, objective = 0.6497723624246929\n",
      "iteration 327, objective = 0.6496051907902811\n",
      "iteration 328, objective = 0.6494385938876306\n",
      "iteration 329, objective = 0.6492726705291054\n",
      "iteration 330, objective = 0.6491067834491095\n",
      "iteration 331, objective = 0.6489419183099127\n",
      "iteration 332, objective = 0.6487766119452926\n",
      "iteration 333, objective = 0.6486122435081505\n",
      "iteration 334, objective = 0.6484466832075546\n",
      "iteration 335, objective = 0.6482824658354124\n",
      "iteration 336, objective = 0.6481185872052473\n",
      "iteration 337, objective = 0.6479547448764892\n",
      "iteration 338, objective = 0.6477927217441323\n",
      "iteration 339, objective = 0.6476284593343735\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-69-3e439fcb3bbb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m     \u001b[0mobj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainingStep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minteractions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m     \u001b[0mlosses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlosses\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     print(\"iteration \" + str(i) + \", objective = \" +\n",
      "\u001b[0;32m<ipython-input-68-67918c84b1a4>\u001b[0m in \u001b[0;36mtrainingStep\u001b[0;34m(interactions)\u001b[0m\n\u001b[1;32m     12\u001b[0m             \u001b[0msampleK\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitemIDs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msampleU\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msampleI\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msampleJ\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msampleK\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m     \u001b[0mgradients\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtape\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgradient\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainable_variables\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m    817\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    818\u001b[0m           \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_maybe_build\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 819\u001b[0;31m           \u001b[0mcast_inputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_maybe_cast_inputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    820\u001b[0m           with base_layer_utils.autocast_context_manager(\n\u001b[1;32m    821\u001b[0m               self._compute_dtype):\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m_maybe_cast_inputs\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m   1762\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1763\u001b[0m           \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1764\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmap_structure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1765\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1766\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/util/nest.py\u001b[0m in \u001b[0;36mmap_structure\u001b[0;34m(func, *structure, **kwargs)\u001b[0m\n\u001b[1;32m    566\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    567\u001b[0m   return pack_sequence_as(\n\u001b[0;32m--> 568\u001b[0;31m       \u001b[0mstructure\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mentries\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    569\u001b[0m       expand_composites=expand_composites)\n\u001b[1;32m    570\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/util/nest.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    566\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    567\u001b[0m   return pack_sequence_as(\n\u001b[0;32m--> 568\u001b[0;31m       \u001b[0mstructure\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mentries\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    569\u001b[0m       expand_composites=expand_composites)\n\u001b[1;32m    570\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m   1751\u001b[0m         cast_types = (ops.Tensor, sparse_tensor.SparseTensor,\n\u001b[1;32m   1752\u001b[0m                       ragged_tensor.RaggedTensor)\n\u001b[0;32m-> 1753\u001b[0;31m         if (isinstance(x, cast_types) and x.dtype.is_floating and\n\u001b[0m\u001b[1;32m   1754\u001b[0m             x.dtype.base_dtype.name != compute_dtype):\n\u001b[1;32m   1755\u001b[0m           \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dtype_defaulted_to_floatx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/framework/dtypes.py\u001b[0m in \u001b[0;36mis_floating\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    147\u001b[0m     \u001b[0;34m\"\"\"Returns whether this is a (non-quantized, real) floating point type.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    148\u001b[0m     return ((self.is_numpy_compatible and\n\u001b[0;32m--> 149\u001b[0;31m              np.issubdtype(self.as_numpy_dtype, np.floating)) or\n\u001b[0m\u001b[1;32m    150\u001b[0m             self.base_dtype == bfloat16)\n\u001b[1;32m    151\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/framework/dtypes.py\u001b[0m in \u001b[0;36mas_numpy_dtype\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    125\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0mas_numpy_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m     \u001b[0;34m\"\"\"Returns a `numpy.dtype` based on this `DType`.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_TF_TO_NP\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_enum\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    129\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Training driver: run a fixed number of optimisation steps and report the mean\n",
    "# of the most recent objectives after every step.\n",
    "N_ITERATIONS = 1000  # number of calls to trainingStep\n",
    "LOSS_WINDOW = 1000   # how many of the latest objectives are averaged\n",
    "\n",
    "# Since LOSS_WINDOW equals N_ITERATIONS, nothing is ever trimmed here and the\n",
    "# printed value is the mean over all iterations so far; lower LOSS_WINDOW to\n",
    "# get a true moving average.\n",
    "losses = []\n",
    "for i in range(N_ITERATIONS):\n",
    "    obj = trainingStep(interactions)\n",
    "    # Append in place and drop entries that fell out of the window, rather\n",
    "    # than rebuilding the whole list on every iteration.\n",
    "    losses.append(obj)\n",
    "    del losses[:-LOSS_WINDOW]\n",
    "    print(\"iteration \" + str(i) + \", objective = \" +\n",
    "          str(sum(losses) / len(losses)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
