Python - SciPy: multivariate_normal - Select the Right Subsets of Input
Any help that pushes me towards the right solution is greatly appreciated...
I am trying to do a classification in two steps:
1.) Calculate mu, sigma, and pi on the training set.
2.) Create a test routine that takes
- mu, sigma, pi
- an array of feature IDs
- the test data testx and testy.
Part 1.) works. It returns
- mu, shape (4, 13)
- sigma, shape (4, 13, 13)
- pi, shape (4,)
import numpy as np

def fit_generative_model(x, y):
    k = 3  # labels 1,2,...,k
    d = x.shape[1]  # number of features
    mu = np.zeros((k+1, d))        # row 0 is unused so that rows line up with labels
    sigma = np.zeros((k+1, d, d))
    pi = np.zeros(k+1)
    for label in range(1, k+1):
        indices = (y == label)
        mu[label] = np.mean(x[indices, :], axis=0)
        sigma[label] = np.cov(x[indices, :], rowvar=0, bias=1)  # MLE covariance (divides by n)
        pi[label] = float(sum(indices)) / float(len(y))          # class prior
    return mu, sigma, pi
Part 2.) does not work: I can't figure out how to select the right subsets of mu and sigma for the chosen features.
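import numpy as np
from scipy.stats import multivariate_normal

def test_model(mu, sigma, pi, features, tx, ty):
    # mu, sigma, pi are the outputs of fit_generative_model(trainx, trainy);
    # tx and ty are the test data (testx, testy)
    k = 3  # labels 1,2,...,k
    nt = len(ty)
    score = np.zeros((nt, k+1))
    covar = sigma
    for i in range(0, nt):
        for label in range(1, k+1):
            # tx[i, features] has len(features) entries, but mu[label, :] and
            # covar[label, :, :] still span all 13 features: this is the mismatch
            score[i, label] = np.log(pi[label]) + \
                multivariate_normal.logpdf(tx[i, features], mean=mu[label, :], cov=covar[label, :, :])
    predictions = np.argmax(score[:, 1:k+1], axis=1) + 1
    errors = np.sum(predictions != ty)
    return errors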
It should return the number of mistakes made by the generative model on the test data when restricted to the specified features.
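Restricting the model to a subset of features means taking the matching entries of each class mean and the matching square sub-block of each class covariance. The NumPy sketch below contrasts the two ways of indexing a covariance array with a list of feature IDs; the 2-class, 4-feature toy arrays are hypothetical stand-ins for the real `mu` (4, 13) and `sigma` (4, 13, 13).

```python
import numpy as np

# Hypothetical toy stand-ins: 2 classes, 4 features
base = np.arange(16, dtype=float).reshape(4, 4)
base = base + base.T                       # symmetric; entry (i, j) equals 5*i + 5*j
sigma = np.stack([base, 2 * base])         # shape (2, 4, 4), like sigma[label]
mu = np.array([[0.0, 1.0, 2.0, 3.0],
               [10.0, 11.0, 12.0, 13.0]])  # shape (2, 4), like mu[label]

label = 1
features = [0, 2]  # hypothetical feature subset

# Paired fancy indexing zips the two lists: it picks (0, 0) and (2, 2) only
paired = sigma[label, features, features]

# np.ix_ forms the cross product rows {0, 2} x cols {0, 2}: the full sub-block
block = sigma[label][np.ix_(features, features)]

# The mean vector needs only one index list
sub_mu = mu[label, features]

print(paired.tolist())  # [0.0, 40.0]
print(block.tolist())   # [[0.0, 20.0], [20.0, 40.0]]
print(sub_mu.tolist())  # [10.0, 12.0]
```

Only the `np.ix_` version has the (f, f) shape that a covariance matrix over f features needs; the paired version is just its diagonal.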
Answer
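Index mu and sigma with the feature list as well, so that the mean vector and the covariance matrix have the same dimension as the feature-restricted test point. For the covariance, use np.ix_ to take the full square sub-block: covar[label, features, features] pairs the two lists element-wise and returns only the diagonal entries, which scipy then interprets as a diagonal covariance, silently dropping the correlations between features.
mean=mu[label, features], cov=covar[label][np.ix_(features, features)]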
source: stackoverflow.com