from collections import defaultdict
import numpy

import hungarian

import sys

# Yelp business records; the file is read as one JSON object per line.
BUSINESS_FILE = "/afs/cs.stanford.edu/u/jmcauley/scratch/beercode/language/yelpExp/language/yelp_phoenix_academic_dataset/yelp_academic_dataset_business.json"

cats = set()               # every category name, spaces replaced by '_'
#out = open("categories.txt", 'w')
ccount = defaultdict(int)  # category name -> number of businesses listing it
bcats = {}                 # business_id -> set of that business's categories

# `with` guarantees the handle is closed; iterating the file object avoids
# holding every line in memory at once.
with open(BUSINESS_FILE, 'r') as business_file:
  for line in business_file:
    if not line.strip():
      continue  # a blank line would make eval() raise SyntaxError
    # SECURITY NOTE: eval() executes arbitrary expressions found in the file,
    # so this is only acceptable for a trusted dataset. The textual
    # true/false substitution also rewrites those words inside string values
    # and does not handle JSON null. json.loads() is the safe replacement.
    record = eval(line.replace("true", "True").replace("false", "False"))
    categories = set(record['categories'])

    #out.write(record['business_id'] + ' ' + ' '.join([x.replace(' ', '_') for x in record['categories']]) + '\n')
    bcats[record['business_id']] = categories
    for x in categories:
      ccount[x] += 1
      cats.add(x.replace(' ', '_'))
#out.close()

# (count, name) pairs, most frequent category first. The pairs are unique
# (names are dict keys), so a descending sort equals sort-then-reverse.
topcats_ = sorted(((ccount[x], x) for x in ccount), reverse=True)
#print topcats
#print len(cats)

#for b in bcats.keys():
  #cs = set([cat for cat in bcats[b] if cat in topcats])
  #out.write(b + ' ' + ' '.join([cat.replace(' ', '_') for cat in cs]) + '\n')
#out.close()

def processModel(modelFile):
  """Read a model file and return the index of each item's largest score.

  Every non-blank line is split on whitespace: the first field is used as
  the item's identifier and the remaining fields are parsed as floats.

  Args:
    modelFile: path of the model file to read.

  Returns:
    A dict mapping each identifier to the 0-based position of its largest
    score (the first such position on ties). If an identifier occurs on
    several lines, the last line wins.

  Raises:
    ValueError: if a field after the identifier is not a number, or a line
      holds an identifier with no scores.
  """
  bestCat = {}
  # `with` closes the file even if a line fails to parse.
  with open(modelFile, 'r') as model:
    for line in model:
      fields = line.split()
      if not fields:
        continue  # tolerate blank lines (e.g. a trailing newline)
      scores = [float(x) for x in fields[1:]]
      bestCat[fields[0]] = scores.index(max(scores))
  return bestCat
  
def F1(s1, s2):
  """Return the F1 score (harmonic mean of precision and recall) of two sets.

  Precision is the overlap divided by len(s2) and recall is the overlap
  divided by len(s1). The result is 0 when either set is empty or when the
  sets share no element.
  """
  if not (len(s1) and len(s2)):
    return 0
  overlap = len(s1.intersection(s2))
  if overlap == 0:
    # Precision and recall are both zero; avoid dividing by their sum.
    return 0
  # "* 1.0" forces float division under Python 2.
  r = overlap * 1.0 / len(s1)
  p = overlap * 1.0 / len(s2)
  return 2 * (p*r) / (p + r)

  
def improvement(a, b):
  """Return the relative improvement of a over b, as a percentage.

  Args:
    a: the new score.
    b: the baseline score; the change is measured relative to it.

  Returns:
    A (text, value) tuple. value is (a - b) / b * 100. text is str(value)
    cut off (not rounded) after at most two decimals, followed by the
    LaTeX-escaped percent sign, e.g. '12.34\\%'.
  """
  # "* 1.0" forces float division so integer arguments are not truncated
  # under Python 2 (same idiom as F1).
  pct = (a - b) * 1.0 / b * 100
  text = str(pct)
  if 'e' in text:
    # str() switched to exponent notation (very small or very large value);
    # slicing that at the '.' would keep the mantissa and drop the exponent.
    text = '%f' % pct
  dot = text.find('.')
  if dot != -1:
    # Keep the '.' plus up to two digits. Strings without a '.' (such as
    # 'inf' or 'nan') are returned whole.
    text = text[:dot + 3]
  return text + '\\%', pct

def greedy(modelFile, bcats, topcats):
  """Score how well a model's dimensions line up with the given categories.

  Despite the name, the matching is not greedy: an F1 matrix is built
  between every category and every model dimension, hungarian.lap is used
  to pair them up, and the F1 of the chosen pairs is averaged.

  Args:
    modelFile: path of a model file readable by processModel.
    bcats: dict mapping business id -> collection of category names.
    topcats: list of category names to evaluate against; its length is
      the number of categories and model dimensions that are matched.

  Returns:
    The mean F1 over the len(topcats) matched (category, dimension) pairs,
    or 0.0 if topcats is empty.
  """
  # Take the size from the argument instead of relying on a module-level K
  # being set by the caller.
  K = len(topcats)
  if K == 0:
    return 0.0
  bestCat = processModel(modelFile)

  # businessesWithBest[k] = businesses whose strongest dimension is k.
  # Built in one pass instead of once per matrix cell; dimensions >= K are
  # ignored, as they never matched any column index before either.
  businessesWithBest = [set() for _ in range(K)]
  for b, k in bestCat.items():
    if k < K:
      businessesWithBest[k].add(b)

  mat = numpy.zeros((K, K))
  for k1, cat in enumerate(topcats):
    businessesWithCat = set(b for b in bcats if cat in bcats[b])
    for k2 in range(K):
      mat[k1, k2] = F1(businessesWithCat, businessesWithBest[k2])

  # The matrix is negated before the call, presumably because hungarian.lap
  # minimises total cost (confirm against the hungarian package).
  # tup_row[i] is used below as the column paired with row i.
  tup_row, tup_col = hungarian.lap(-mat)
  return sum([mat[i, tup_row[i]] for i in range(K)]) / K

# For each model size K, score three models against the K most frequent
# categories and print one LaTeX table row:
#   K & LFR & LDA & ours & improvement over LFR & improvement over LDA \\
for K in (5, 10, 20, 50):
  topcats = [x[1] for x in topcats_[:K]]
  gLFR = greedy("model_1_0_" + str(K), bcats, topcats)
  gLDA = greedy("models/lda" + str(K) + ".model", bcats, topcats)
  gMe = greedy("model_0_1_" + str(K), bcats, topcats)
  cells = [
    str(K),
    str(gLFR)[:5],  # scores are cut to their first five characters
    str(gLDA)[:5],
    str(gMe)[:5],
    improvement(gMe, gLFR)[0],
    improvement(gMe, gLDA)[0],
  ]
  # A single parenthesised argument prints the same text under the Python 2
  # print statement and the Python 3 print function.
  print(' & '.join(cells) + '\\\\')
