Source code for pimlico.core.external.java

# This file is part of Pimlico
# Copyright (C) 2020 Mark Granroth-Wilding
# Licensed under the GNU LGPL v3.0 - https://www.gnu.org/licenses/lgpl-3.0.en.html

from __future__ import print_function

import logging

from future import standard_library
standard_library.install_aliases()
from builtins import zip
from builtins import str
from builtins import object

import sys
import time
from queue import Empty
from io import StringIO
from collections import deque
from subprocess import Popen, PIPE
from threading import Thread

import os
from pimlico import JAVA_LIB_DIR, JAVA_BUILD_JAR_DIR
from pimlico.core.logs import get_log_file
from pimlico.utils.pipes import qget

# Classpath entries that are added to every classpath built in this module, including
# custom ones supplied by callers: a wildcard over everything in JAVA_BUILD_JAR_DIR.
ALWAYS_INCLUDE_IN_CLASSPATH = ["%s/*" % JAVA_BUILD_JAR_DIR]
# Classpath used when the caller does not supply one: a wildcard over JAVA_LIB_DIR,
# followed by the always-included entries, joined with ":".
# NOTE(review): ":" is the Unix classpath separator and is also hard-coded where the
# classpath is split/joined elsewhere in this module -- presumably Windows is not a target; confirm.
DEFAULT_CLASSPATH = ":".join(["%s/*" % JAVA_LIB_DIR] + ALWAYS_INCLUDE_IN_CLASSPATH)


def call_java(class_name, args=[], classpath=None):
    """
    Run a Java class in a subprocess and block until it has finished.

    The command is built by java_call_command() and the given args are appended
    to it.

    :param class_name: name of the Java class to run
    :param args: list of command-line arguments placed after the class name
    :param classpath: custom classpath passed on to java_call_command(), or None
        to use the default
    :return: tuple (stdout, stderr, return code), where the two outputs have been
        decoded as UTF-8
    """
    # May in future want to allow the path to the java executable to be specified in local config
    command = java_call_command(class_name, classpath=classpath) + args
    java_process = Popen(command, stdin=PIPE, stdout=PIPE, stderr=PIPE, shell=False)
    raw_stdout, raw_stderr = java_process.communicate()
    return raw_stdout.decode("utf-8"), raw_stderr.decode("utf-8"), java_process.returncode
def java_call_command(class_name, classpath=None):
    """
    Build the list of command components for a subprocess call to Java, as used
    by call_java.

    :param class_name: name of the Java class to run
    :param classpath: custom classpath string. The entries of
        ALWAYS_INCLUDE_IN_CLASSPATH are appended to it. If None, DEFAULT_CLASSPATH
        is used
    :return: list ``["java", "-cp", <classpath>, <class_name>]``
    """
    if classpath is not None:
        full_classpath = ":".join([classpath] + ALWAYS_INCLUDE_IN_CLASSPATH)
    else:
        full_classpath = DEFAULT_CLASSPATH
    # May in future want to allow the path to the java executable to be specified in local config
    return ["java", "-cp", full_classpath, class_name]
def start_java_process(class_name, args=[], java_args=[], wait=0.1, classpath=None):
    """
    Start a Java class running in a background subprocess and return the process.

    After starting the process, we wait briefly and check that it is still
    running. If it has already exited, an error is raised that includes whatever
    the process wrote to stderr.

    :param class_name: name of the Java class to run
    :param args: sequence of arguments placed after the class name
    :param java_args: sequence of arguments for the ``java`` executable itself,
        placed before the class name
    :param wait: number of seconds to wait before checking that the process is
        still running
    :param classpath: custom classpath string. The entries of
        ALWAYS_INCLUDE_IN_CLASSPATH are appended to it. If None, DEFAULT_CLASSPATH
        is used
    :return: the Popen object, with stdin, stdout and stderr pipes. The command
        that was run is available as a string in its ``command_run`` attribute
    :raises JavaProcessError: if the process has exited by the time ``wait``
        seconds have passed, whatever its return code
    """
    if classpath is None:
        classpath = DEFAULT_CLASSPATH
    else:
        classpath = ":".join([classpath] + ALWAYS_INCLUDE_IN_CLASSPATH)

    # May in future want to allow the path to the java executable to be specified in local config
    cmd = ["java", "-cp", classpath] + list(java_args) + [class_name] + list(args)
    command_str = " ".join(cmd)
    process = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, shell=False)
    # Attach the command to the Popen object so it's easy to read what was run for debugging
    process.command_run = command_str

    # Wait a mo for it to get started
    time.sleep(wait)
    if process.poll() is not None:
        # The pipe gives us bytes: decode them so that the message contains the actual
        # text and not the repr of a bytes object
        error_output = process.stderr.read()
        if isinstance(error_output, bytes):
            error_output = error_output.decode("utf-8", "replace")

        if process.returncode == 0:
            raise JavaProcessError("java process exited immediately with return code 0: %s (ran: %s)" %
                                   (error_output, command_str))
        else:
            raise JavaProcessError("java process failed with return code %d: %s (ran: %s)" %
                                   (process.returncode, error_output, command_str))
    return process
class Py4JInterface(object):
    """
    Manages a Py4J gateway server running in a Java subprocess, together with the
    client gateways that have been connected to it.

    Can be used as a context manager: the server is started on entry and stopped
    on exit.

    """
    def __init__(self, gateway_class, port=None, python_port=None, gateway_args=[], pipeline=None,
                 print_stdout=True, print_stderr=True, env={}, system_properties={}, java_opts=[],
                 timeout=10., prefix_classpath=None):
        """
        If pipeline is given, configuration is looked for there. If found, this overrides config given
        in other kwargs.

        If print_stdout=True (default), stdout from processes will be printed out to the console in addition to any
        other processing that's done to it. Same with stderr. By default, both are output to the console.

        env adds extra variables to the environment for running the Java process.

        system_properties adds Java system property settings to the Java command.

        gateway_args, java_opts, env and system_properties are copied, so the
        instance never shares state with the caller's objects or with the default
        values in the signature.

        """
        from py4j.compat import Queue

        self.prefix_classpath = prefix_classpath
        # Take copies of the mutable arguments: the defaults are shared between all calls
        self.java_opts = list(java_opts)
        self.system_properties = dict(system_properties)
        self.env = dict(env)
        self.gateway_args = list(gateway_args)
        self.print_stderr = print_stderr
        self.print_stdout = print_stdout
        self.python_port = python_port
        self.gateway_class = gateway_class
        self.port = port
        self.timeout = timeout

        # Everything the server outputs is put on these queues
        self.stderr_queue = Queue()
        self.stdout_queue = Queue()

        # Look for config in the pipeline
        if pipeline is not None:
            start_port = pipeline.local_config.get("py4j_port", None)
            if start_port is not None:
                # Config gives just a single port number
                # If it's given, use the following port for the other direction of communication
                self.port = int(start_port)
                self.python_port = int(start_port) + 1

            if "py4j_timeout" in pipeline.local_config:
                # Override the timeout given as an arg
                self.timeout = float(pipeline.local_config["py4j_timeout"])

        self.process = None
        self.gateway = None
        self.port_used = None
        self.clients = []

    def start(self, timeout=None, port_output_prefix=None):
        """
        Start a Py4J gateway server in the background on the given port, which will then be used for
        communicating with the Java app.

        If a port has been given, it is assumed that the gateway accepts a --port option.
        Likewise with python_port and a --python-port option.

        If timeout is given, it overrides any timeout given in the constructor or specified in local config.

        If the server is launched but connecting the first client fails, the server
        is stopped again before the error is re-raised.

        """
        if timeout is None:
            # Use the default timeout specified on the instance
            timeout = self.timeout

        args = list(self.gateway_args)
        if self.port is not None:
            args.extend(["--port", "%d" % self.port])
        if self.python_port is not None:
            args.extend(["--python-port", "%d" % self.python_port])

        # We could add other things as well here, like queues, to capture the output
        redirect_stdout = [self.stdout_queue]
        if self.print_stdout:
            redirect_stdout.append(sys.stdout)
        redirect_stderr = [self.stderr_queue]
        if self.print_stderr:
            redirect_stderr.append(sys.stderr)

        # Allow Java system properties to be set on the command line
        java_opts = list(self.java_opts)
        for prop, val in self.system_properties.items():
            java_opts.append("-D%s=%s" % (prop, val))

        self.port_used, self.process = launch_gateway(
            self.gateway_class, args,
            redirect_stdout=redirect_stdout, redirect_stderr=redirect_stderr,
            env=self.env, javaopts=java_opts, startup_timeout=timeout, port_output_prefix=port_output_prefix,
            prefix_classpath=self.prefix_classpath,
        )
        try:
            self.gateway = self.new_client()
        except Exception:
            # Don't leave the server running if we couldn't connect to it: when used as a
            # context manager, __exit__ doesn't get called if __enter__ fails
            self.stop()
            raise

    def new_client(self):
        """
        Connect a new client gateway to the running server on the port it was
        started on. The client is remembered so that stop() can close it.

        :return: the new client gateway
        """
        client = gateway_client_to_running_server(self.port_used)
        self.clients.append(client)
        return client

    def stop(self):
        """
        Close all client gateways and terminate the server process, waiting for it
        to exit.

        Safe to call if the server was never started. The server process is
        terminated even if closing a client fails.

        """
        import errno

        try:
            # Stop the client gateway(s)
            for client_gateway in self.clients:
                client_gateway.close()
        finally:
            self.gateway = None
            self.clients = []
            # Stop the server process, if one was ever started
            if self.process is not None:
                try:
                    self.process.terminate()
                except OSError as e:
                    # No such process: process is already dead
                    # Raise other errors
                    if e.errno != errno.ESRCH:
                        raise
                self.process.wait()

    def clear_output_queues(self):
        """
        Discard everything currently waiting on the stdout and stderr queues.

        """
        while not self.stdout_queue.empty():
            self.stdout_queue.get_nowait()
        while not self.stderr_queue.empty():
            self.stderr_queue.get_nowait()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
def no_retry_gateway(**kwargs):
    """
    Wrapper around the JavaGateway constructor, intended to produce a gateway
    that doesn't retry on errors.

    The default gateway keeps retrying and outputting millions of errors if the
    server goes down, which makes responding to interrupts horrible (as the server
    might die before the Python process gets the interrupt).

    TODO This isn't working: it just gets worse when I use my version!
    For that reason the no-retry client defined here is currently not installed
    and what is returned is a standard JavaGateway built from the kwargs.

    """
    from py4j.java_gateway import JavaGateway, GatewayClient

    class NoRetryGatewayClient(GatewayClient):
        # Forces retry=False, whatever the caller asks for
        def send_command(self, command, retry=True):
            return super(NoRetryGatewayClient, self).send_command(command, retry=False)

    # Replacing the gateway client of a normally-started gateway is disabled (see TODO above):
    #gateway = JavaGateway(**kwargs)
    #gateway_client = NoRetryGatewayClient(
    #    address=gateway._gateway_client.address,
    #    port=gateway._gateway_client.port,
    #    auto_close=gateway._gateway_client.auto_close
    #)
    #gateway.set_gateway_client(gateway_client)
    return JavaGateway(**kwargs)
def gateway_client_to_running_server(port):
    """
    Create a client gateway for a gateway server that is already running on the
    given port.

    :param port: port passed to Py4J's GatewayParameters
    :return: gateway produced by no_retry_gateway()
    """
    from py4j.java_gateway import GatewayParameters

    parameters = GatewayParameters(port=port)
    return no_retry_gateway(gateway_parameters=parameters)
def launch_gateway(gateway_class="py4j.GatewayServer", args=[], javaopts=[], redirect_stdout=None,
                   redirect_stderr=None, daemonize_redirect=True, env={}, port_output_prefix=None,
                   startup_timeout=10., prefix_classpath=None):
    """
    Our own more flexible version of Py4J's launch_gateway.

    Starts the gateway class in a Java subprocess, starts threads that consume its
    stdout and stderr, and reads the port the server is listening on from the
    process' stdout.

    :param gateway_class: name of the Java class to run
    :param args: list of arguments placed after the class name
    :param javaopts: list of options for ``java``, placed before the class name
    :param redirect_stdout: list of queues, deques or file-like objects that every
        line of the process' stdout is sent to
    :param redirect_stderr: the same for stderr
    :param daemonize_redirect: whether the consumer threads are daemon threads
    :param env: extra environment variables, added to a copy of the current
        environment for the Java process
    :param port_output_prefix: if given, lines of stdout are skipped until one
        starts with this prefix and the port is read from the rest of that line.
        If None, the port is read from the first line of output
    :param startup_timeout: total number of seconds to wait for the port to be
        output. None means the default of 10 seconds
    :param prefix_classpath: ":"-separated classpath entries to put before the
        default classpath
    :return: tuple (port used by the server, Popen object of the Java process)
    :raises JavaProcessError: if the port is not output within the timeout, the
        server outputs ``ERROR``, or the output cannot be interpreted as a port
    """
    from py4j.java_gateway import ProcessConsumer
    from py4j.compat import Queue
    # Make sure that java_gateway logger isn't set to debug, as it outputs tonnes
    from py4j.java_gateway import logger as java_gateway_logger
    if java_gateway_logger.level <= logging.DEBUG:
        java_gateway_logger.setLevel(logging.INFO)

    if startup_timeout is None:
        startup_timeout = 10.

    # Add custom environment variables to the ones we've already got
    java_env = os.environ.copy()
    java_env.update(env)

    # Allow extra things to be added to the start of the classpath
    if prefix_classpath is not None:
        classpath = ":".join(prefix_classpath.split(":") + DEFAULT_CLASSPATH.split(":"))
    else:
        classpath = DEFAULT_CLASSPATH

    # Launch the server in a subprocess.
    command = ["java", "-classpath", classpath] + javaopts + [gateway_class] + args
    proc = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE, env=java_env)

    if redirect_stdout is None:
        redirect_stdout = []
    # Startup output is read from this queue to find the port
    stdout_queue = Queue()

    if redirect_stderr is None:
        redirect_stderr = []
    # Collects stderr during startup, so it can be included in error reports
    stderr_capture = StringIO()

    def get_stderr():
        return stderr_capture.getvalue()

    # Start consumer threads so process does not deadlock/hang
    stdout_consumer = OutputConsumer(redirect_stdout, proc.stdout, daemon=daemonize_redirect,
                                     temporary_redirects=[stdout_queue])
    stderr_consumer = OutputConsumer(redirect_stderr, proc.stderr, daemon=daemonize_redirect,
                                     temporary_redirects=[stderr_capture])
    stdout_consumer.start()
    stderr_consumer.start()
    # Also start Py4J's ProcessConsumer thread for the process
    ProcessConsumer(proc, [redirect_stdout], daemon=daemonize_redirect).start()

    # Determine which port the server started on (needed to support ephemeral ports)
    # Don't hang on an error running the gateway launcher
    output = None
    # The timeout applies to the whole startup, not to each line we read
    deadline = time.time() + startup_timeout
    try:
        while True:
            remaining_time = deadline - time.time()
            if remaining_time <= 0.:
                # Out of time without having seen the port: handled like a timeout on the queue
                raise Empty()
            # Get lines from stdout until the timeout is reached
            output = qget(stdout_queue, timeout=remaining_time)
            if port_output_prefix is None or output.strip("\n ") == "ERROR":
                # Don't look for a particular prefix, just use the first line we get, or if it's an error
                # don't keep waiting
                break
            elif output.startswith(port_output_prefix):
                # If the line doesn't begin with the right prefix, keep waiting
                # If it does, strip the prefix, so that we get just the port number
                output = output[len(port_output_prefix):]
                break
    except Empty:
        # Timed out waiting for server to start
        error_output = get_stderr()
        err_path = output_p4j_error_info(command, "(timed out)", "", error_output)

        try:
            # Give up and kill the process, if it's still running
            proc.terminate()
        except OSError:
            # Already terminated
            pass
        raise JavaProcessError("timed out starting gateway server (for details see %s)" % err_path)
    except Exception as e:
        # Try reading stderr to see if there's any info there
        error_output = get_stderr()
        err_path = output_p4j_error_info(command, "?", "could not read", error_output)

        raise JavaProcessError("error reading first line from gateway process: %s. Error output: %s (see %s for "
                               "more details)" % (e, error_output, err_path))

    # Check whether there was an error reported
    output = output.strip("\n ")
    if output == "ERROR":
        # Read error output from stderr
        error_output = get_stderr()
        raise JavaProcessError("Py4J gateway had an error starting up: %s" % error_output)

    try:
        port_used = int(output)
    except ValueError:
        returncode = proc.poll()
        stderr_output = get_stderr()
        err_path = output_p4j_error_info(command, returncode, output, stderr_output)

        if returncode is not None:
            raise JavaProcessError("Py4J server process returned with return code %s: %s (see %s for details)" %
                                   (returncode, stderr_output, err_path))
        else:
            raise JavaProcessError("invalid first line output from Py4J server when started: '%s' (see %s for details)"
                                   % (output, err_path))

    # Stop the temporary redirects that were only going to capture startup output
    stdout_consumer.remove_temporary_redirects()
    stderr_consumer.remove_temporary_redirects()

    return port_used, proc
def _pipe_queue(redirect, line):
    """Send a line of output to a queue, by putting it on the queue."""
    redirect.put(line)


def _pipe_deque(redirect, line):
    """Send a line of output to a deque, by adding it at the left-hand end."""
    redirect.appendleft(line)


def _pipe_fd(redirect, line):
    """Send a line of output to a file-like object, by writing it."""
    redirect.write(line)
def get_redirect_func(redirect):
    """
    Choose the function used to send a line of output to the given redirect.

    Queues are checked for first, then deques, then anything that has a ``write``
    attribute.

    :param redirect: the object that output is to be sent to
    :return: function taking (redirect, line), or None if the redirect is not of
        one of the types above
    """
    from py4j.compat import hasattr2, Queue

    if isinstance(redirect, Queue):
        pipe_func = _pipe_queue
    elif isinstance(redirect, deque):
        pipe_func = _pipe_deque
    elif hasattr2(redirect, "write"):
        pipe_func = _pipe_fd
    else:
        pipe_func = None
    return pipe_func
class OutputConsumer(Thread):
    """
    Thread that consumes output

    Modification of Py4J's OutputConsumer to allow multiple redirects: each line
    read from the stream is sent to all of the redirects, and then to all of the
    temporary redirects, until those are removed.

    """
    def __init__(self, redirects, stream, *args, **kwargs):
        """
        :param redirects: list of redirects (or a single one) that lines are sent to
        :param stream: stream that lines are read from
        :param temporary_redirects: (kwarg) list of further redirects, which can
            later be dropped using remove_temporary_redirects()
        :param daemon: (kwarg) whether this is a daemon thread. Default: False

        Any other args and kwargs are passed on to Thread.
        """
        temporary = kwargs.pop("temporary_redirects", [])
        run_as_daemon = kwargs.pop("daemon", False)
        super(OutputConsumer, self).__init__(*args, **kwargs)
        self.daemon = run_as_daemon
        self.stream = stream

        # Also allow the one-redirect case, just like Py4J's class
        self.redirects = redirects if isinstance(redirects, list) else [redirects]
        self.redirect_funcs = [get_redirect_func(target) for target in self.redirects]

        self.temporary_redirects = temporary
        self.temporary_redirect_funcs = [get_redirect_func(target) for target in self.temporary_redirects]

    def remove_temporary_redirects(self):
        """
        Stop sending output to the temporary redirects.

        """
        self.temporary_redirects = []
        self.temporary_redirect_funcs = []

    def run(self):
        """
        Read lines from the stream until an empty bytes object is read, sending
        each decoded line to the redirects.

        """
        from py4j.protocol import smart_decode

        while True:
            line = self.stream.readline()
            if line == b"":
                break
            # Permanent redirects first, then whichever temporary ones are still in place
            for target, send in zip(self.redirects, self.redirect_funcs):
                send(target, smart_decode(line))
            for target, send in zip(self.temporary_redirects, self.temporary_redirect_funcs):
                send(target, smart_decode(line))
def output_p4j_error_info(command, returncode, stdout, stderr):
    """
    Write details of a failed attempt to run a Py4J server to the "py4j" log file.

    :param command: list of strings making up the command that was run
    :param returncode: return code to report
    :param stdout: what was read from the process' stdout
    :param stderr: what was read from the process' stderr
    :return: path of the file that the details were written to
    """
    log_path = get_log_file("py4j")
    with open(log_path, "w") as log_file:
        def write_line(content):
            print(content, file=log_file)

        write_line("Command:")
        write_line(" ".join(command))
        write_line("Return code: %s" % returncode)
        write_line("Read from stdout:")
        write_line(stdout)
        write_line("Read from stderr:")
        write_line(stderr)
    return log_path
def make_py4j_errors_safe(fn):
    """
    Decorator for functions/methods that call Py4J. Py4J's exceptions include information that gets
    retrieved from the Py4J server when they're displayed. This is a problem if the server is no longer
    running and raises another exception, making the whole situation very confusing.

    If you wrap your function with this, Py4JJavaErrors will be replaced by our own exception type
    Py4JSafeJavaError, containing some of the information about the Java exception if possible.

    The wrapper keeps the name and docstring of the wrapped function.

    :param fn: function to wrap
    :return: wrapped function, which raises Py4JSafeJavaError where fn raises Py4JJavaError
    """
    import functools
    from py4j.protocol import Py4JJavaError

    @functools.wraps(fn)
    def _wrapped_fn(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Py4JJavaError as e:
            # Try getting the java exception and string repr, but don't throw everything in the air if we can't
            # Only catch Exception, so that interrupts are not swallowed while we do this
            try:
                java_exception = str(e.java_exception)
            except Exception:
                java_exception = None
            try:
                str_repr = str(e)
            except Exception:
                str_repr = None
            # NOTE(review): on Python 3 the original Py4JJavaError remains attached to the new
            # exception as its __context__ -- confirm that displaying it is not a problem
            raise Py4JSafeJavaError(java_exception, str_repr)
    return _wrapped_fn
class Py4JSafeJavaError(Exception):
    """
    Replacement for Py4J's Py4JJavaError that carries its information as plain
    values, so that it can be displayed without contacting the Py4J server.

    :param java_exception: string form of the Java exception, or None if it could
        not be retrieved
    :param str: full string form of the original error, or None if it could not
        be retrieved

    """
    def __init__(self, java_exception=None, str=None):
        super(Py4JSafeJavaError, self).__init__()
        self.str = str
        self.java_exception = java_exception

    def __str__(self):
        if self.str is not None:
            return self.str
        if self.java_exception is not None:
            # No full description available: show the Java exception, not an empty message
            return "%s" % (self.java_exception,)
        return super(Py4JSafeJavaError, self).__str__()
class DependencyCheckerError(Exception):
    """
    Plain exception type with no behaviour of its own. Presumably for errors
    while checking dependencies; it is not raised anywhere in this module.

    """
class JavaProcessError(Exception):
    """
    Raised when a Java subprocess or Py4J gateway server fails to start up
    correctly.

    """