# Source code for xenonpy.mdl.base

#  Copyright (c) 2021. yoshida-lab. All rights reserved.
#  Use of this source code is governed by a BSD-style
#  license that can be found in the LICENSE file.

import json
from os import remove
from pathlib import Path
from shutil import make_archive

import pandas as pd
import requests
from requests import HTTPError
from sklearn.base import BaseEstimator

from xenonpy.utils import TimedMetaClass


class BaseQuery(BaseEstimator, metaclass=TimedMetaClass):
    """Base class for callable queries posted to a remote HTTP endpoint.

    Subclasses must set the class attribute ``queryable`` to a list of field
    names and implement :meth:`gql`. Calling an instance builds the query,
    posts it together with ``variables`` to ``endpoint`` and returns the part
    of the response stored under ``data.<className>``, where ``<className>``
    is the subclass name with its first letter lower-cased.

    Parameters
    ----------
    variables
        Query variables; serialised under the ``variables`` key of the
        request payload, so they must be JSON serialisable.
    api_key
        Value sent in the ``api_key`` request header.
    endpoint
        URL the request is posted to.

    Raises
    ------
    RuntimeError
        If the subclass did not define ``queryable``.
    """

    # Names of the fields that can be requested; must be overridden by
    # subclasses (checked in ``__init__``).
    queryable = None

    def __init__(self,
                 variables,
                 *,
                 api_key: str = 'anonymous.user.key',
                 endpoint: str = 'http://xenon.ism.ac.jp/api'):
        if self.queryable is None:
            raise RuntimeError('Query class must give a queryable field in list of string')

        self._results = None
        self._return_json = False
        self._endpoint = endpoint
        self._api_key = api_key
        self._variables = variables

    @property
    def api_key(self):
        """Key sent in the ``api_key`` request header."""
        return self._api_key

    @property
    def endpoint(self):
        """URL the query is posted to."""
        return self._endpoint

    @property
    def variables(self):
        """Variables sent along with the query."""
        return self._variables

    @property
    def results(self):
        """Result of the last call that returned data, or ``None``."""
        return self._results

    def gql(self, *query_vars: str):
        """Return the query text for the requested fields.

        The returned value is sent under the ``query`` key of the request
        payload (presumably a GraphQL document -- confirm against the
        concrete subclasses). Must be implemented by subclasses.
        """
        raise NotImplementedError()

    @staticmethod
    def _post(ret, return_json):
        """Post-process a response payload.

        Returns ``ret`` untouched when ``return_json`` is truthy; otherwise
        wraps a non-list payload in a list and converts it to a
        :class:`pandas.DataFrame`.
        """
        if return_json:
            return ret
        records = ret if isinstance(ret, list) else [ret]
        return pd.DataFrame(records)

    def check_query_vars(self, *query_vars: str):
        """Validate that every requested field is listed in ``queryable``.

        Returns
        -------
        tuple of str
            The given ``query_vars``, unchanged.

        Raises
        ------
        RuntimeError
            If any name is not contained in ``queryable``.
        """
        if not set(query_vars) <= set(self.queryable):
            raise RuntimeError(f'`query_vars` contains illegal variables, '
                               f'available querying variables are: {self.queryable}')
        return query_vars

    def __call__(self, *querying_vars, file=None, return_json=None):
        """Execute the query and return its result.

        Parameters
        ----------
        querying_vars
            Fields to request. When omitted, every field in ``queryable`` is
            requested.
        file
            Optional path. It is packed into a temporary ``gztar`` archive
            (created next to it, using the path itself as archive root) which
            is uploaded with the request as a multipart form and deleted
            afterwards.
        return_json
            If truthy, return the decoded response payload as is; otherwise
            convert it to a :class:`pandas.DataFrame`. ``None`` selects the
            instance default (``False``).

        Returns
        -------
        pandas.DataFrame, decoded JSON payload, or None
            ``None`` when the server returned an empty result.

        Raises
        ------
        RuntimeError
            If ``querying_vars`` contains a name not in ``queryable``.
        requests.HTTPError
            If the response status code is not 200.
        ValueError
            If the response body contains an ``errors`` entry.
        """
        if not querying_vars:
            query = self.gql(*self.queryable)
        else:
            query = self.gql(*self.check_query_vars(*querying_vars))

        payload = json.dumps({'query': query, 'variables': self._variables})

        if file is None:
            ret = requests.post(url=self._endpoint,
                                data=payload,
                                headers={
                                    "content-type": "application/json",
                                    'api_key': self._api_key
                                })
        else:
            source = str(Path(file).resolve())
            archive = make_archive(source, 'gztar', source)
            try:
                # Opening inside the ``try`` guarantees the temporary archive
                # is removed even when it cannot be opened, and the ``with``
                # closes the handle before the archive is deleted.
                with open(archive, 'rb') as stream:
                    form_fields = (
                        ('operations', payload),
                        # Maps uploaded part '0' onto ``variables.model``.
                        ('map', json.dumps({0: ['variables.model']})),
                    )
                    ret = requests.post(url=self._endpoint,
                                        data=form_fields,
                                        headers={'api_key': self._api_key},
                                        files={'0': stream})
            finally:
                remove(archive)

        if ret.status_code != 200:
            try:
                message = ret.json()
            except ValueError:
                # ``ValueError`` is the common base of the decode errors of
                # both ``json`` and ``simplejson`` (requests may use either),
                # so a non-JSON error body never masks the HTTP error.
                message = "Server did not responce."
            raise HTTPError(f'status_code: {ret.status_code}, {message}')

        ret = ret.json()
        if 'errors' in ret:
            raise ValueError(ret['errors'][0]['message'])

        # The result is stored under the class name with a lower-cased
        # first letter.
        query_name = self.__class__.__name__
        ret = ret['data'][query_name[0].lower() + query_name[1:]]
        if not ret:
            return None

        if return_json is None:
            return_json = self._return_json

        ret = self._post(ret, return_json)
        self._results = ret
        return ret

    def __repr__(self, N_CHAR_MAX=700):
        """Return the estimator repr followed by the queryable fields."""
        queryable = '\n '.join(self.queryable)
        return f'{super().__repr__(N_CHAR_MAX)}\nQueryable: \n {queryable}'