| #!/usr/bin/env python |
| # -*- coding: utf-8 -*- |
| # |
| # Copyright 2014 Google Inc. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Simple command-line sample for the Google Prediction API |
| |
| Command-line application that trains on your input data. This sample does |
| the same thing as the Hello Prediction! example. You might want to run |
| the setup.sh script to load the sample data to Google Storage. |
| |
| Usage: |
| $ python prediction.py "bucket/object" "model_id" "project_id" |
| |
| You can also get help on all the command-line flags the program understands |
| by running: |
| |
| $ python prediction.py --help |
| |
| To get detailed log output, add the logging flag to the usual arguments: |
| |
| $ python prediction.py "bucket/object" "model_id" "project_id" --logging_level=DEBUG |
| """ |
| from __future__ import print_function |
| |
| __author__ = ('[email protected] (Joe Gregorio), ' |
| '[email protected] (Marc Cohen)') |
| |
| import argparse |
| import pprint |
| import sys |
| import time |
| |
| from apiclient import sample_tools |
| from oauth2client import client |
| |
| |
# Polling interval, in seconds, used while waiting for training to finish.
SLEEP_TIME = 10


# Positional command-line arguments for this sample. Help is disabled here
# because this parser is handed to sample_tools.init as a parent parser.
argparser = argparse.ArgumentParser(add_help=False)
for _flag_name, _flag_help in (
        ('object_name',
         'Full Google Storage path of csv data (ex bucket/object)'),
        ('model_id',
         'Model Id of your choosing to name trained model'),
        ('project_id',
         'Project Id of your Google Cloud Project'),
):
    argparser.add_argument(_flag_name, help=_flag_help)
# Do not leave the loop variables behind as module globals.
del _flag_name, _flag_help
| |
| |
def print_header(line):
    """Print line framed above and below by a rule of '=' characters.

    The rule is exactly as long as line, and a blank line is emitted before
    the top rule.

    Args:
      line: The header text to print.
    """
    rule = '=' * len(line)
    for row in ('\n' + rule, line, rule):
        print(row)
| |
| |
def _wait_for_training(papi, model_id, project_id):
    """Poll a model until its training status leaves 'RUNNING'.

    Args:
      papi: The trainedmodels() resource of the Prediction API service.
      model_id: Id of the model whose training is being awaited.
      project_id: Id of the project the model belongs to.

    Returns:
      The response of the last get call, whose 'trainingStatus' is 'DONE'.

    Raises:
      Exception: If the reported status is neither 'RUNNING' nor 'DONE'.
    """
    while True:
        status = papi.get(id=model_id, project=project_id).execute()
        state = status['trainingStatus']
        print('Training state: ' + state)
        if state == 'DONE':
            return status
        if state != 'RUNNING':
            raise Exception('Training Error: ' + state)
        # Still training: pause before asking again.
        time.sleep(SLEEP_TIME)


def main(argv):
    """Run the Prediction API sample from start to finish.

    Lists up to ten existing models, submits a training request for the
    given data, waits for training to finish, then analyzes the model,
    runs two sample predictions against it and finally deletes it. If the
    stored credentials can no longer be refreshed, a message is printed
    and the function returns.

    Args:
      argv: The full command-line argument list (the script passes
        sys.argv), handed to sample_tools.init together with argparser.

    Raises:
      Exception: If training ends in a state other than 'DONE'.
    """
    # If you previously ran this app with an earlier version of the API
    # or if you change the list of scopes below, revoke your app's permission
    # here: https://accounts.google.com/IssuedAuthSubTokens
    # Then re-run the app to re-authorize it.
    service, flags = sample_tools.init(
        argv, 'prediction', 'v1.6', __doc__, __file__, parents=[argparser],
        scope=(
            'https://www.googleapis.com/auth/prediction',
            'https://www.googleapis.com/auth/devstorage.read_only'))

    try:
        # Get access to the Prediction API.
        papi = service.trainedmodels()

        # List models.
        print_header('Fetching list of first ten models')
        result = papi.list(maxResults=10, project=flags.project_id).execute()
        print('List results:')
        pprint.pprint(result)

        # Start training request on a data set.
        print_header('Submitting model training request')
        body = {'id': flags.model_id,
                'storageDataLocation': flags.object_name}
        start = papi.insert(body=body, project=flags.project_id).execute()
        print('Training results:')
        pprint.pprint(start)

        # Wait for the training to complete.
        print_header('Waiting for training to complete')
        status = _wait_for_training(papi, flags.model_id, flags.project_id)

        # Job has completed. This report used to sit inside the polling
        # loop after an if/elif/else that always left the iteration, so it
        # never ran; it now runs once training is done.
        print('Training completed:')
        pprint.pprint(status)

        # Describe model.
        print_header('Fetching model description')
        result = papi.analyze(
            id=flags.model_id, project=flags.project_id).execute()
        print('Analyze results:')
        pprint.pprint(result)

        # Make some predictions using the newly trained model.
        print_header('Making some predictions')
        for sample_text in ['mucho bueno', 'bonjour, mon cher ami']:
            body = {'input': {'csvInstance': [sample_text]}}
            result = papi.predict(
                body=body, id=flags.model_id,
                project=flags.project_id).execute()
            print('Prediction results for "%s"...' % sample_text)
            pprint.pprint(result)

        # Delete model. The response is not used, so it is not kept.
        print_header('Deleting model')
        papi.delete(id=flags.model_id, project=flags.project_id).execute()
        print('Model deleted.')

    except client.AccessTokenRefreshError:
        print('The credentials have been revoked or expired, please re-run '
              'the application to re-authorize.')
| |
| |
# Run the sample only when executed as a script, not when imported.
if __name__ == '__main__':
    main(argv=sys.argv)