diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..b678f571 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,64 @@ +name: mycli + +on: + pull_request: + paths-ignore: + - '**.md' + +jobs: + linux: + strategy: + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + include: + - python-version: 3.6 + os: ubuntu-18.04 # MySQL 5.7.32 + - python-version: 3.7 + os: ubuntu-18.04 # MySQL 5.7.32 + - python-version: 3.8 + os: ubuntu-18.04 # MySQL 5.7.32 + - python-version: 3.9 + os: ubuntu-20.04 # MySQL 8.0.22 + + runs-on: ${{ matrix.os }} + steps: + + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Start MySQL + run: | + sudo /etc/init.d/mysql start + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-dev.txt + pip install --no-cache-dir -e . + + - name: Wait for MySQL connection + run: | + while ! mysqladmin ping --host=localhost --port=3306 --user=root --password=root --silent; do + sleep 5 + done + + - name: Pytest / behave + env: + PYTEST_PASSWORD: root + PYTEST_HOST: 127.0.0.1 + run: | + ./setup.py test --pytest-args="--cov-report= --cov=mycli" + + - name: Lint + run: | + ./setup.py lint --branch=HEAD + + - name: Coverage + run: | + coverage combine + coverage report + codecov diff --git a/.gitignore b/.gitignore index e1d59532..b13429e4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea/ +.vscode/ /build /dist /mycli.egg-info diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 3b4d98ca..00000000 --- a/.travis.yml +++ /dev/null @@ -1,34 +0,0 @@ -language: python -python: - - "2.7" - - "3.4" - - "3.5" - - "3.6" - -matrix: - include: - - python: 3.7 - dist: xenial - sudo: true - -install: - - pip install -r requirements-dev.txt - - pip install -e . - - sudo rm -f /etc/mysql/conf.d/performance-schema.cnf - - sudo service mysql restart - -script: - - ./setup.py test --pytest-args="--cov-report= --cov=mycli" - - coverage combine - - coverage report - - ./setup.py lint --branch=$TRAVIS_BRANCH - -after_success: - - codecov - -notifications: - webhooks: - urls: - - YOUR_WEBHOOK_URL - on_success: change # options: [always|never|change] default: always - on_failure: always # options: [always|never|change] default: always diff --git a/README.md b/README.md index b2426db5..cc04a910 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # mycli -[![Build Status](https://travis-ci.org/dbcli/mycli.svg?branch=master)](https://travis-ci.org/dbcli/mycli) -[![PyPI](https://img.shields.io/pypi/v/mycli.svg?style=plastic)](https://pypi.python.org/pypi/mycli) -[![Join the chat at https://gitter.im/dbcli/mycli](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dbcli/mycli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Build Status](https://github.com/dbcli/mycli/workflows/mycli/badge.svg)](https://github.com/dbcli/mycli/actions?query=workflow%3Amycli) +[![PyPI](https://img.shields.io/pypi/v/mycli.svg)](https://pypi.python.org/pypi/mycli) +[![LGTM](https://img.shields.io/lgtm/grade/python/github/dbcli/mycli.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/dbcli/mycli/context:python) A command line client for MySQL that can do auto-completion and syntax highlighting. -HomePage: [http://mycli.net](http://mycli.net) +HomePage: [http://mycli.net](http://mycli.net) Documentation: [http://mycli.net/docs](http://mycli.net/docs) ![Completion](screenshots/tables.png) @@ -53,6 +53,7 @@ $ sudo apt-get install mycli # Only on debian or ubuntu -h, --host TEXT Host address of the database. -P, --port INTEGER Port number to use for connection. Honors $MYSQL_TCP_PORT. + -u, --user TEXT User name to connect to the database. -S, --socket TEXT The socket file to use for connection. -p, --password TEXT Password to connect to the database. @@ -63,6 +64,11 @@ $ sudo apt-get install mycli # Only on debian or ubuntu --ssh-password TEXT Password to connect to ssh server. --ssh-key-filename TEXT Private key filename (identify file) for the ssh connection. + + --ssh-config-path TEXT Path to ssh configuration. + --ssh-config-host TEXT Host to connect to ssh server reading from ssh + configuration. + --ssl-ca PATH CA file in PEM format. --ssl-capath TEXT CA directory. --ssl-cert PATH X509 cert in PEM format. @@ -71,30 +77,43 @@ $ sudo apt-get install mycli # Only on debian or ubuntu --ssl-verify-server-cert Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default. + -V, --version Output mycli's version. -v, --verbose Verbose output. -D, --database TEXT Database to use. -d, --dsn TEXT Use DSN configured into the [alias_dsn] section of myclirc file. + --list-dsn list of DSN configured into the [alias_dsn] section of myclirc file. + + --list-ssh-config list ssh configurations in the ssh config + (requires paramiko). + -R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> "). -l, --logfile FILENAME Log every query and its results to a file. --defaults-group-suffix TEXT Read MySQL config groups with the specified suffix. + --defaults-file PATH Only read MySQL options from the given file. --myclirc PATH Location of myclirc file. --auto-vertical-output Automatically switch to vertical output mode if the result is wider than the terminal width. + -t, --table Display batch output in table format. --csv Display batch output in CSV format. --warn / --no-warn Warn before running a destructive query. --local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE. - --login-path TEXT Read this path from the login file. + -g, --login-path TEXT Read this path from the login file. -e, --execute TEXT Execute command and quit. + --init-command TEXT SQL statement to execute after connecting. + --charset TEXT Character set for MySQL session. + --password-file PATH File or FIFO path containing the password + to connect to the db if not specified otherwise --help Show this message and exit. + Features -------- @@ -109,7 +128,7 @@ Features * Support for multiline queries. * Favorite queries with optional positional parameters. Save a query using `\fs alias query` and execute it with `\f alias` whenever you need. -* Timing of sql statments and table rendering. +* Timing of sql statements and table rendering. * Config file is automatically created at ``~/.myclirc`` at first launch. * Log every query and its results to a file (disabled by default). * Pretty prints tabular data (with colors!) diff --git a/changelog.md b/changelog.md index 039aa301..b5522d2e 100644 --- a/changelog.md +++ b/changelog.md @@ -1,11 +1,190 @@ -TBD -==== +1.24.4 (2022/03/30) +=================== + +Internal: +--------- +* Upgrade Ubuntu VM for runners as Github has deprecated it + +Bug Fixes: +---------- +* Change in main.py - Replace the `click.get_terminal_size()` with `shutil.get_terminal_size()` + + +1.24.3 (2022/01/20) +=================== + +Bug Fixes: +---------- +* Upgrade cli_helpers to workaround Pygments regression. + + +1.24.2 (2022/01/11) +=================== + +Bug Fixes: +---------- +* Fix autocompletion for more than one JOIN +* Fix the status command when connected to TiDB or other servers that don't implement 'Threads\_connected' +* Pin pygments version to avoid a breaking change + +1.24.1: +======= + +Bug Fixes: +--------- +* Restore dependency on cryptography for the interactive password prompt + +Internal: +--------- +* Deprecate Python mock + + +1.24.0 +====== + +Bug Fixes: +---------- +* Allow `FileNotFound` exception for SSH config files. +* Fix startup error on MySQL < 5.0.22 +* Check error code rather than message for Access Denied error +* Fix login with ~/.my.cnf files + +Features: +--------- +* Add `-g` shortcut to option `--login-path`. +* Alt-Enter dispatches the command in multi-line mode. +* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html) + +Internal: +--------- +* Remove unused function is_open_quote() +* Use importlib, instead of file links, to locate resources +* Test various host-port combinations in command line arguments +* Switched from Cryptography to pyaes for decrypting mylogin.cnf + + +1.23.2 +====== + +Bug Fixes: +---------- +* Ensure `--port` is always an int. + +1.23.1 +====== + +Bug Fixes: +---------- +* Allow `--host` without `--port` to make a TCP connection. + +1.23.0 +====== + +Bug Fixes: +---------- +* Fix config file include logic + +Features: +--------- + +* Add an option `--init-command` to execute SQL after connecting (Thanks: [KITAGAWA Yasutaka]). +* Use InputMode.REPLACE_SINGLE +* Add support for ANSI escape sequences for coloring the prompt. +* Allow customization of Pygments SQL syntax-highlighting styles. +* Add a `\clip` special command to copy queries to the system clipboard. +* Add a special command `\pipe_once` to pipe output to a subprocess. +* Add an option `--charset` to set the default charset when connect database. + +Bug Fixes: +---------- +* Fixed compatibility with sqlparse 0.4 (Thanks: [mtorromeo]). +* Fixed iPython magic (Thanks: [mwcm]). +* Send "Connecting to socket" message to the standard error. +* Respect empty string for prompt_continuation via `prompt_continuation = ''` in `.myclirc` +* Fix \once -o to overwrite output whole, instead of line-by-line. +* Dispatch lines ending with `\e` or `\clip` on return, even in multiline mode. +* Restore working local `--socket=` (Thanks: [xeron]). +* Allow backtick quoting around the database argument to the `use` command. +* Avoid opening `/dev/tty` when `--no-warn` is given. +* Fixed some typo errors in `README.md`. + +1.22.2 +====== + +Bug Fixes: +---------- + +* Make the `pwd` module optional. + +1.22.1 +====== + +Bug Fixes: +---------- +* Fix the breaking change introduced in PyMySQL 0.10.0. (Thanks: [Amjith]). + +Features: +--------- +* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file. +* Add an option `--list-ssh-config` to list ssh configurations. +* Add an option `--ssh-config-path` to choose ssh configuration path. + +Bug Fixes: +---------- + +* Fix specifying empty password with `--password=''` when config file has a password set (Thanks: [Zach DeCook]). + + +1.21.1 +====== + + +Bug Fixes: +---------- + +* Fix broken auto-completion for favorite queries (Thanks: [Amjith]). +* Fix undefined variable exception when running with --no-warn (Thanks: [Georgy Frolov]) +* Support setting color for null value (Thanks: [laixintao]) + +1.21.0 +====== + +Features: +--------- +* Added DSN alias name as a format specifier to the prompt (Thanks: [Georgy Frolov]). +* Mark `update` without `where`-clause as destructive query (Thanks: [Klaus Wünschel]). +* Added DELIMITER command (Thanks: [Georgy Frolov]) +* Added clearer error message when failing to connect to the default socket. +* Extend main.is_dropping_database check with create after delete statement. +* Search `${XDG_CONFIG_HOME}/mycli/myclirc` after `${HOME}/.myclirc` and before `/etc/myclirc` (Thanks: [Takeshi D. Itoh]) + +Bug Fixes: +---------- + +* Allow \o command more than once per session (Thanks: [Georgy Frolov]) +* Fixed crash when the query dropping the current database starts with a comment (Thanks: [Georgy Frolov]) + +Internal: +--------- +* deprecate python versions 2.7, 3.4, 3.5; support python 3.8 + +1.20.1 +====== + +Bug Fixes: +---------- + +* Fix an error when using login paths with an explicit database name (Thanks: [Thomas Roten]). + +1.20.0 +====== Features: ---------- * Auto find alias dsn when `://` not in `database` (Thanks: [QiaoHou Peng]). * Mention URL encoding as escaping technique for special characters in connection DSN (Thanks: [Aljosha Papsch]). * Pressing Alt-Enter will introduce a line break. This is a way to break up the query into multiple lines without switching to multi-line mode. (Thanks: [Amjith Ramanujam]). +* Use a generator to stream the output to the pager (Thanks: [Dick Marinus]). Bug Fixes: ---------- @@ -17,16 +196,10 @@ Bug Fixes: * Update `setup.py` to no longer require `sqlparse` to be less than 0.3.0 as that just came out and there are no notable changes. ([VVelox]) * workaround for ConfigObj parsing strings containing "," as lists (Thanks: [Mike Palandra]) -Features: ---------- - -* Use a generator to stream the output to the pager (Thanks: [Dick Marinus]). - Internal: --------- * fix unhashable FormattedText from prompt toolkit in unit tests (Thanks: [Dick Marinus]). - 1.19.0 ====== @@ -689,25 +862,31 @@ Bug Fixes: ---------- * Fixed the installation issues with PyMySQL dependency on case-sensitive file systems. +[Amjith Ramanujam]: https://blog.amjith.com +[Artem Bezsmertnyi]: https://github.com/mrdeathless +[Carlos Afonso]: https://github.com/afonsocarlos +[Casper Langemeijer]: https://github.com/langemeijer [Daniel West]: http://github.com/danieljwest +[Dick Marinus]: https://github.com/meeuw +[François Pietka]: https://github.com/fpietka +[Frederic Aoustin]: https://github.com/fraoustin +[Georgy Frolov]: https://github.com/pasenor [Irina Truong]: https://github.com/j-bennet -[Amjith Ramanujam]: https://blog.amjith.com +[Jonathan Slenders]: https://github.com/jonathanslenders [Kacper Kwapisz]: https://github.com/KKKas +[laixintao]: https://github.com/laixintao +[Lennart Weller]: https://github.com/lhw [Martijn Engler]: https://github.com/martijnengler [Matheus Rosa]: https://github.com/mdsrosa -[Shoma Suzuki]: https://github.com/shoma -[spacewander]: https://github.com/spacewander -[Thomas Roten]: https://github.com/tsroten -[Artem Bezsmertnyi]: https://github.com/mrdeathless [Mikhail Borisov]: https://github.com/borman -[Casper Langemeijer]: Casper Langemeijer -[Lennart Weller]: https://github.com/lhw +[mtorromeo]: https://github.com/mtorromeo +[mwcm]: https://github.com/mwcm [Phil Cohen]: https://github.com/phlipper +[Scrappy Soft]: https://github.com/scrappysoft +[Shoma Suzuki]: https://github.com/shoma +[spacewander]: https://github.com/spacewander [Terseus]: https://github.com/Terseus +[Thomas Roten]: https://github.com/tsroten [William GARCIA]: https://github.com/willgarcia -[Jonathan Slenders]: https://github.com/jonathanslenders -[Casper Langemeijer]: https://github.com/langemeijer -[Scrappy Soft]: https://github.com/scrappysoft -[Dick Marinus]: https://github.com/meeuw -[François Pietka]: https://github.com/fpietka -[Frederic Aoustin]: https://github.com/fraoustin +[xeron]: https://github.com/xeron +[Zach DeCook]: https://zachdecook.com diff --git a/mycli/AUTHORS b/mycli/AUTHORS index ae1fb08a..d1f3a280 100644 --- a/mycli/AUTHORS +++ b/mycli/AUTHORS @@ -15,59 +15,81 @@ Core Developers: Contributors: ------------- - * Steve Robbins - * Shoma Suzuki - * Daniel West - * Scrappy Soft - * Daniel Black - * Jonathan Bruno - * Casper Langemeijer - * Jonathan Slenders + * 0xflotus + * Abirami P + * Adam Chainz + * Aljosha Papsch + * Andy Teijelo Pérez + * Angelo Lupo * Artem Bezsmertnyi - * Mikhail Borisov + * bitkeen + * bjarnagin + * caitinggui + * Carlos Afonso + * Casper Langemeijer + * chainkite + * Colin Caine + * cxbig + * Daniel Black + * Daniel West + * Daniël van Eeden + * François Pietka + * Frederic Aoustin + * Georgy Frolov * Heath Naylor - * Phil Cohen - * spacewander - * Adam Chainz + * Huachao Mao + * Jakub Boukal + * jbruno + * Jerome Provensal + * Jialong Liu * Johannes Hoff + * John Sterling + * Jonathan Bruno + * Jonathan Lloyd + * Jonathan Slenders * Kacper Kwapisz + * Karthikeyan Singaravelan + * kevinhwang91 + * KITAGAWA Yasutaka + * Klaus Wünschel + * laixintao * Lennart Weller * Martijn Engler + * Massimiliano Torromeo + * Michał Górny + * Mike Palandra + * Mikhail Borisov + * Morgan Mitchell + * mrdeathless + * Nathan Huang + * Nicolas Palumbo + * Phil Cohen + * QiaoHou Peng + * Roland Walker + * Ryan Smith + * Scrappy Soft + * Seamile + * Shoma Suzuki + * spacewander + * Steve Robbins + * Takeshi D. Itoh + * Terje Røsten * Terseus * Tyler Kuipers + * ushuz * William GARCIA + * xeron + * Yang Zou * Yasuhiro Matsumoto - * bjarnagin - * jbruno - * mrdeathless - * Abirami P - * John Sterling - * Jialong Liu - * Zhidong - * Daniël van Eeden + * Zach DeCook + * Zane C. Bowers-Hadley * zer09 - * cxbig - * chainkite - * Michał Górny - * Terje Røsten - * Ryan Smith - * Klaus Wünschel - * François Pietka - * Colin Caine - * Frederic Aoustin - * caitinggui - * ushuz * Zhaolong Zhu + * Zhidong * Zhongyang Guan - * Huachao Mao - * QiaoHou Peng - * Yang Zou - * Angelo Lupo - * Aljosha Papsch - * Zane C. Bowers-Hadley - * Mike Palandra + * Arvind Mishra -Creator: --------- +Created by: +----------- Amjith Ramanujam diff --git a/mycli/__init__.py b/mycli/__init__.py index 8ac48f09..e10d6ee2 100644 --- a/mycli/__init__.py +++ b/mycli/__init__.py @@ -1 +1 @@ -__version__ = '1.19.0' +__version__ = '1.24.4' diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py index f6cc737a..81353b63 100644 --- a/mycli/clibuffer.py +++ b/mycli/clibuffer.py @@ -1,9 +1,7 @@ -from __future__ import unicode_literals - from prompt_toolkit.enums import DEFAULT_BUFFER from prompt_toolkit.filters import Condition from prompt_toolkit.application import get_app -from .packages.parseutils import is_open_quote +from .packages import special def cli_is_multiline(mycli): @@ -17,6 +15,7 @@ def cond(): return not _multiline_exception(doc.text) return cond + def _multiline_exception(text): orig = text text = text.strip() @@ -27,12 +26,30 @@ def _multiline_exception(text): if text.startswith('\\fs'): return orig.endswith('\n') - return (text.startswith('\\') or # Special Command - text.endswith(';') or # Ended with a semi-colon - text.endswith('\\g') or # Ended with \g - text.endswith('\\G') or # Ended with \G - (text == 'exit') or # Exit doesn't need semi-colon - (text == 'quit') or # Quit doesn't need semi-colon - (text == ':q') or # To all the vim fans out there - (text == '') # Just a plain enter without any text - ) + return ( + # Special Command + text.startswith('\\') or + + # Delimiter declaration + text.lower().startswith('delimiter') or + + # Ended with the current delimiter (usually a semi-column) + text.endswith(special.get_current_delimiter()) or + + text.endswith('\\g') or + text.endswith('\\G') or + text.endswith(r'\e') or + text.endswith(r'\clip') or + + # Exit doesn't need semi-column` + (text == 'exit') or + + # Quit doesn't need semi-column + (text == 'quit') or + + # To all teh vim fans out there + (text == ':q') or + + # just a plain enter without any text + (text == '') + ) diff --git a/mycli/clistyle.py b/mycli/clistyle.py index 6f8b03af..b0ac9922 100644 --- a/mycli/clistyle.py +++ b/mycli/clistyle.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals - import logging import pygments.styles @@ -36,6 +34,7 @@ Token.Output.Header: 'output.header', Token.Output.OddRow: 'output.odd-row', Token.Output.EvenRow: 'output.even-row', + Token.Output.Null: 'output.null', Token.Prompt: 'prompt', Token.Continuation: 'continuation', } @@ -45,6 +44,36 @@ v: k for k, v in TOKEN_TO_PROMPT_STYLE.items() } +# all tokens that the Pygments MySQL lexer can produce +OVERRIDE_STYLE_TO_TOKEN = { + 'sql.comment': Token.Comment, + 'sql.comment.multi-line': Token.Comment.Multiline, + 'sql.comment.single-line': Token.Comment.Single, + 'sql.comment.optimizer-hint': Token.Comment.Special, + 'sql.escape': Token.Error, + 'sql.keyword': Token.Keyword, + 'sql.datatype': Token.Keyword.Type, + 'sql.literal': Token.Literal, + 'sql.literal.date': Token.Literal.Date, + 'sql.symbol': Token.Name, + 'sql.quoted-schema-object': Token.Name.Quoted, + 'sql.quoted-schema-object.escape': Token.Name.Quoted.Escape, + 'sql.constant': Token.Name.Constant, + 'sql.function': Token.Name.Function, + 'sql.variable': Token.Name.Variable, + 'sql.number': Token.Number, + 'sql.number.binary': Token.Number.Bin, + 'sql.number.float': Token.Number.Float, + 'sql.number.hex': Token.Number.Hex, + 'sql.number.integer': Token.Number.Integer, + 'sql.operator': Token.Operator, + 'sql.punctuation': Token.Punctuation, + 'sql.string': Token.String, + 'sql.string.double-quouted': Token.String.Double, + 'sql.string.escape': Token.String.Escape, + 'sql.string.single-quoted': Token.String.Single, + 'sql.whitespace': Token.Text, +} def parse_pygments_style(token_name, style_object, style_dict): """Parse token type and style string. @@ -109,6 +138,9 @@ def style_factory_output(name, cli_style): elif token in PROMPT_STYLE_TO_TOKEN: token_type = PROMPT_STYLE_TO_TOKEN[token] style.update({token_type: cli_style[token]}) + elif token in OVERRIDE_STYLE_TO_TOKEN: + token_type = OVERRIDE_STYLE_TO_TOKEN[token] + style.update({token_type: cli_style[token]}) else: # TODO: cli helpers will have to switch to ptk.Style logger.error('Unhandled style / class name: %s', token) diff --git a/mycli/clitoolbar.py b/mycli/clitoolbar.py index 89e6afa0..eec2978f 100644 --- a/mycli/clitoolbar.py +++ b/mycli/clitoolbar.py @@ -1,8 +1,7 @@ -from __future__ import unicode_literals - from prompt_toolkit.key_binding.vi_state import InputMode from prompt_toolkit.application import get_app from prompt_toolkit.enums import EditingMode +from .packages import special def create_toolbar_tokens_func(mycli, show_fish_help): @@ -12,8 +11,13 @@ def get_toolbar_tokens(): result.append(('class:bottom-toolbar', ' ')) if mycli.multi_line: + delimiter = special.get_current_delimiter() result.append( - ('class:bottom-toolbar', ' (Semi-colon [;] will end the line) ')) + ( + 'class:bottom-toolbar', + ' ({} [{}] will end the line) '.format( + 'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter) + )) if mycli.multi_line: result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON ')) @@ -44,5 +48,6 @@ def _get_vi_mode(): InputMode.INSERT: 'I', InputMode.NAVIGATION: 'N', InputMode.REPLACE: 'R', + InputMode.REPLACE_SINGLE: 'R', InputMode.INSERT_MULTIPLE: 'M', }[get_app().vi_state.input_mode] diff --git a/mycli/compat.py b/mycli/compat.py index ee1167b0..2ebfe07f 100644 --- a/mycli/compat.py +++ b/mycli/compat.py @@ -1,9 +1,6 @@ -# -*- coding: utf-8 -*- """Platform and Python version compatibility support.""" import sys -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 WIN = sys.platform in ('win32', 'cygwin') diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py index e6c8dd07..124068a9 100644 --- a/mycli/completion_refresher.py +++ b/mycli/completion_refresher.py @@ -36,7 +36,7 @@ def refresh(self, executor, callbacks, completer_options=None): target=self._bg_refresh, args=(executor, callbacks, completer_options), name='completion_refresh') - self._completer_thread.setDaemon(True) + self._completer_thread.daemon = True self._completer_thread.start() return [(None, None, None, 'Auto-completion refresh started in the background.')] diff --git a/mycli/config.py b/mycli/config.py index 43b33394..5d711093 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -1,15 +1,20 @@ -from __future__ import print_function -import shutil +from copy import copy from io import BytesIO, TextIOWrapper import logging import os from os.path import exists import struct import sys +from typing import Union, IO from configobj import ConfigObj, ConfigObjError -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.backends import default_backend +import pyaes + +try: + import importlib.resources as resources +except ImportError: + # Python < 3.7 + import importlib_resources as resources try: basestring @@ -29,18 +34,27 @@ def log(logger, level, message): print(message, file=sys.stderr) -def read_config_file(f): - """Read a config file.""" +def read_config_file(f, list_values=True): + """Read a config file. + + *list_values* set to `True` is the default behavior of ConfigObj. + Disabling it causes values to not be parsed for lists, + (e.g. 'a,b,c' -> ['a', 'b', 'c']. Additionally, the config values are + not unquoted. We are disabling list_values when reading MySQL config files + so we can correctly interpret commas in passwords. + + """ if isinstance(f, basestring): f = os.path.expanduser(f) try: - config = ConfigObj(f, interpolation=False, encoding='utf8') + config = ConfigObj(f, interpolation=False, encoding='utf8', + list_values=list_values) except ConfigObjError as e: - log(logger, logging.ERROR, "Unable to parse line {0} of config file " + log(logger, logging.WARNING, "Unable to parse line {0} of config file " "'{1}'.".format(e.line_number, f)) - log(logger, logging.ERROR, "Using successfully parsed config values.") + log(logger, logging.WARNING, "Using successfully parsed config values.") return e.config except (IOError, OSError) as e: log(logger, logging.WARNING, "You don't have permission to read " @@ -50,13 +64,50 @@ def read_config_file(f): return config -def read_config_files(files): +def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list: + """Get a list of configuration files that are included into config_path + with !includedir directive. + + "Normal" configs should be passed as file paths. The only exception + is .mylogin which is decoded into a stream. However, it never + contains include directives and so will be ignored by this + function. + + """ + if not isinstance(config_file, str) or not os.path.isfile(config_file): + return [] + included_configs = [] + + try: + with open(config_file) as f: + include_directives = filter( + lambda s: s.startswith('!includedir'), + f + ) + dirs = map(lambda s: s.strip().split()[-1], include_directives) + dirs = filter(os.path.isdir, dirs) + for dir in dirs: + for filename in os.listdir(dir): + if filename.endswith('.cnf'): + included_configs.append(os.path.join(dir, filename)) + except (PermissionError, UnicodeDecodeError): + pass + return included_configs + + +def read_config_files(files, list_values=True): """Read and merge a list of config files.""" - config = ConfigObj() + config = create_default_config(list_values=list_values) + _files = copy(files) + while _files: + _file = _files.pop(0) + _config = read_config_file(_file, list_values=list_values) - for _file in files: - _config = read_config_file(_file) + # expand includes only if we were able to parse config + # (otherwise we'll just encounter the same errors again) + if config is not None: + _files = get_included_configs(_file) + _files if bool(_config) is True: config.merge(_config) config.filename = _config.filename @@ -64,12 +115,21 @@ def read_config_files(files): return config -def write_default_config(source, destination, overwrite=False): +def create_default_config(list_values=True): + import mycli + default_config_file = resources.open_text(mycli, 'myclirc') + return read_config_file(default_config_file, list_values=list_values) + + +def write_default_config(destination, overwrite=False): + import mycli + default_config = resources.read_text(mycli, 'myclirc') destination = os.path.expanduser(destination) if not overwrite and exists(destination): return - shutil.copyfile(source, destination) + with open(destination, 'w') as f: + f.write(default_config) def get_mylogin_cnf_path(): @@ -112,6 +172,58 @@ def open_mylogin_cnf(name): return TextIOWrapper(plaintext) +# TODO reuse code between encryption an decryption +def encrypt_mylogin_cnf(plaintext: IO[str]): + """Encryption of .mylogin.cnf file, analogous to calling + mysql_config_editor. + + Code is based on the python implementation by Kristian Koehntopp + https://github.com/isotopp/mysql-config-coder + + """ + def realkey(key): + """Create the AES key from the login key.""" + rkey = bytearray(16) + for i in range(len(key)): + rkey[i % 16] ^= key[i] + return bytes(rkey) + + def encode_line(plaintext, real_key, buf_len): + aes = pyaes.AESModeOfOperationECB(real_key) + text_len = len(plaintext) + pad_len = buf_len - text_len + pad_chr = bytes(chr(pad_len), "utf8") + plaintext = plaintext.encode() + pad_chr * pad_len + encrypted_text = b''.join( + [aes.encrypt(plaintext[i: i + 16]) + for i in range(0, len(plaintext), 16)] + ) + return encrypted_text + + LOGIN_KEY_LENGTH = 20 + key = os.urandom(LOGIN_KEY_LENGTH) + real_key = realkey(key) + + outfile = BytesIO() + + outfile.write(struct.pack("i", 0)) + outfile.write(key) + + while True: + line = plaintext.readline() + if not line: + break + real_len = len(line) + pad_len = (int(real_len / 16) + 1) * 16 + + outfile.write(struct.pack("i", pad_len)) + x = encode_line(line, real_key, pad_len) + outfile.write(x) + + outfile.seek(0) + return outfile + + def read_and_decrypt_mylogin_cnf(f): """Read and decrypt the contents of .mylogin.cnf. @@ -153,11 +265,9 @@ def read_and_decrypt_mylogin_cnf(f): return None rkey = struct.pack('16B', *rkey) - # Create a decryptor object using the key. - decryptor = _get_decryptor(rkey) - # Create a bytes buffer to hold the plaintext. plaintext = BytesIO() + aes = pyaes.AESModeOfOperationECB(rkey) while True: # Read the length of the ciphertext. @@ -168,7 +278,10 @@ def read_and_decrypt_mylogin_cnf(f): # Read cipher_len bytes from the file and decrypt. cipher = f.read(cipher_len) - plain = _remove_pad(decryptor.update(cipher)) + plain = _remove_pad( + b''.join([aes.decrypt(cipher[i: i + 16]) + for i in range(0, cipher_len, 16)]) + ) if plain is False: continue plaintext.write(plain) @@ -196,18 +309,24 @@ def str_to_bool(s): elif s.lower() in false_values: return False else: - raise ValueError('not a recognized boolean value: %s'.format(s)) + raise ValueError('not a recognized boolean value: {0}'.format(s)) -def _get_decryptor(key): - """Get the AES decryptor.""" - c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend()) - return c.decryptor() +def strip_matching_quotes(s): + """Remove matching, surrounding quotes from a string. + + This is the same logic that ConfigObj uses when parsing config + values. + + """ + if (isinstance(s, basestring) and len(s) >= 2 and + s[0] == s[-1] and s[0] in ('"', "'")): + s = s[1:-1] + return s def _remove_pad(line): """Remove the pad from the *line*.""" - pad_length = ord(line[-1:]) try: # Determine pad length. pad_length = ord(line[-1:]) @@ -218,7 +337,7 @@ def _remove_pad(line): if pad_length > len(line) or len(set(line[-pad_length:])) != 1: # Pad length should be less than or equal to the length of the - # plaintext. The pad should have a single unqiue byte. + # plaintext. The pad should have a single unique byte. logger.warning('Invalid pad found in login path file.') return False diff --git a/mycli/encodingutils.py b/mycli/encodingutils.py deleted file mode 100644 index d53076f1..00000000 --- a/mycli/encodingutils.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from mycli.compat import PY2 - - -if PY2: - text_type = unicode - binary_type = str -else: - text_type = str - binary_type = bytes - - -def unicode2utf8(arg): - """Convert strings to UTF8-encoded bytes. - - Only in Python 2. In Python 3 the args are expected as unicode. - - """ - - if PY2 and isinstance(arg, text_type): - return arg.encode('utf-8') - return arg - - -def utf8tounicode(arg): - """Convert UTF8-encoded bytes to strings. - - Only in Python 2. In Python 3 the errors are returned as strings. - - """ - - if PY2 and isinstance(arg, binary_type): - return arg.decode('utf-8') - return arg diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py index 53ff55ef..4a24c82b 100644 --- a/mycli/key_bindings.py +++ b/mycli/key_bindings.py @@ -1,4 +1,3 @@ -from __future__ import unicode_literals import logging from prompt_toolkit.enums import EditingMode from prompt_toolkit.filters import completion_is_selected @@ -79,8 +78,12 @@ def _(event): @kb.add('escape', 'enter') def _(event): - """Introduces a line break regardless of multi-line mode or not.""" + """Introduces a line break in multi-line mode, or dispatches the + command in single-line mode.""" _logger.debug('Detected alt-enter key.') - event.app.current_buffer.insert_text('\n') + if mycli.multi_line: + event.app.current_buffer.validate_and_handle() + else: + event.app.current_buffer.insert_text('\n') return kb diff --git a/mycli/magic.py b/mycli/magic.py index 5527f72d..aad229a5 100644 --- a/mycli/magic.py +++ b/mycli/magic.py @@ -19,7 +19,7 @@ def load_ipython_extension(ipython): def mycli_line_magic(line): _logger.debug('mycli magic called: %r', line) parsed = sql.parse.parse(line, {}) - conn = sql.connection.Connection.get(parsed['connection']) + conn = sql.connection.Connection(parsed['connection']) try: # A corresponding mycli object already exists @@ -30,7 +30,7 @@ def mycli_line_magic(line): u = conn.session.engine.url _logger.debug('New mycli: %r', str(u)) - mycli.connect(u.database, u.host, u.username, u.port, u.password) + mycli.connect(host=u.host, port=u.port, passwd=u.password, database=u.database, user=u.username, init_command=None) conn._mycli = mycli # For convenience, print the connection alias diff --git a/mycli/main.py b/mycli/main.py index 968a6b0d..c13ed780 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,18 +1,22 @@ -from __future__ import unicode_literals -from __future__ import print_function - +from collections import defaultdict +from io import open import os import sys +import shutil import traceback import logging import threading import re +import stat import fileinput from collections import namedtuple +try: + from pwd import getpwuid +except ImportError: + pass from time import time from datetime import datetime from random import choice -from io import open from pymysql import OperationalError from cli_helpers.tabular_output import TabularOutputFormatter @@ -20,11 +24,14 @@ from cli_helpers.utils import strip_ansi import click import sqlparse +from mycli.packages.parseutils import is_dropping_database, is_destructive from prompt_toolkit.completion import DynamicCompleter from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register from prompt_toolkit.shortcuts import PromptSession, CompleteStyle from prompt_toolkit.document import Document from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ANSI from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor, ConditionalProcessor) from prompt_toolkit.lexers import PygmentsLexer @@ -35,20 +42,21 @@ from .packages.prompt_utils import confirm, confirm_destructive_query from .packages.tabular_output import sql_format from .packages import special +from .packages.special.favoritequeries import FavoriteQueries from .sqlcompleter import SQLCompleter from .clitoolbar import create_toolbar_tokens_func from .clistyle import style_factory, style_factory_output -from .sqlexecute import FIELD_TYPES, SQLExecute +from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED from .clibuffer import cli_is_multiline from .completion_refresher import CompletionRefresher from .config import (write_default_config, get_mylogin_cnf_path, - open_mylogin_cnf, read_config_files, str_to_bool) + open_mylogin_cnf, read_config_files, str_to_bool, + strip_matching_quotes) from .key_bindings import mycli_bindings -from .encodingutils import utf8tounicode, text_type from .lexer import MyCliLexer -from .__init__ import __version__ +from . import __version__ from .compat import WIN -from .packages.filepaths import dir_path_exists +from .packages.filepaths import dir_path_exists, guess_socket_location import itertools @@ -61,16 +69,24 @@ from urllib.parse import urlparse from urllib.parse import unquote +try: + import importlib.resources as resources +except ImportError: + # Python < 3.7 + import importlib_resources as resources try: import paramiko except ImportError: - paramiko = False + from mycli.packages.paramiko_stub import paramiko # Query tuples are used for maintaining history Query = namedtuple('Query', ['query', 'successful', 'mutating']) -PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__)) +SUPPORT_INFO = ( + 'Home: http://mycli.net\n' + 'Bug tracker: https://github.com/dbcli/mycli/issues' +) class MyCli(object): @@ -84,14 +100,19 @@ class MyCli(object): '/etc/my.cnf', '/etc/mysql/my.cnf', '/usr/local/etc/my.cnf', - '~/.my.cnf' + os.path.expanduser('~/.my.cnf'), ] + # check XDG_CONFIG_HOME exists and not an empty string + if os.environ.get("XDG_CONFIG_HOME"): + xdg_config_home = os.environ.get("XDG_CONFIG_HOME") + else: + xdg_config_home = "~/.config" system_config_files = [ '/etc/myclirc', + os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc") ] - default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc') pwd_config_file = os.path.join(os.getcwd(), ".myclirc") def __init__(self, sqlexecute=None, prompt=None, @@ -111,15 +132,16 @@ def __init__(self, sqlexecute=None, prompt=None, self.cnf_files = [defaults_file] # Load config. - config_files = ([self.default_config_file] + self.system_config_files + + config_files = (self.system_config_files + [myclirc] + [self.pwd_config_file]) c = self.config = read_config_files(config_files) self.multi_line = c['main'].as_bool('multi_line') self.key_bindings = c['main']['key_bindings'] special.set_timing_enabled(c['main'].as_bool('timing')) - special.set_favorite_queries(self.config) + FavoriteQueries.instance = FavoriteQueries.from_config(self.config) + self.dsn_alias = None self.formatter = TabularOutputFormatter( format_name=c['main']['table_format']) sql_format.register_new_formatter(self.formatter) @@ -141,8 +163,8 @@ def __init__(self, sqlexecute=None, prompt=None, c['main'].as_bool('auto_vertical_output') # Write user config if system config wasn't the last config loaded. - if c.filename not in self.system_config_files: - write_default_config(self.default_config_file, myclirc) + if c.filename not in self.system_config_files and not os.path.exists(myclirc): + write_default_config(myclirc) # audit log if self.logfile is None and 'audit_log' in c['main']: @@ -161,7 +183,7 @@ def __init__(self, sqlexecute=None, prompt=None, prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt'] self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \ self.default_prompt - self.prompt_continuation_format = c['main']['prompt_continuation'] + self.multiline_continuation_char = c['main']['prompt_continuation'] keyword_casing = c['main'].get('keyword_casing', 'auto') self.query_history = [] @@ -220,13 +242,16 @@ def change_table_format(self, arg, **_): yield (None, None, None, msg) def change_db(self, arg, **_): - if arg is '': + if not arg: click.secho( "No database selected", err=True, fg="red" ) return + if arg.startswith('`') and arg.endswith('`'): + arg = re.sub(r'^`(.*)`$', r'\1', arg) + arg = re.sub(r'``', r'`', arg) self.sqlexecute.change_db(arg) yield (None, None, None, 'You are now connected to database "%s" as ' @@ -237,7 +262,7 @@ def execute_from_file(self, arg, **_): message = 'Missing required argument, filename.' return [(None, None, None, message)] try: - with open(os.path.expanduser(arg), encoding='utf-8') as f: + with open(os.path.expanduser(arg)) as f: query = f.read() except IOError as e: return [(None, None, None, str(e))] @@ -308,28 +333,36 @@ def read_my_cnf_files(self, files, keys): :param keys: list of keys to retrieve :returns: tuple, with None for missing keys. """ - cnf = read_config_files(files) + cnf = read_config_files(files, list_values=False) + + sections = ['client', 'mysqld'] + key_transformations = { + 'mysqld': { + 'socket': 'default_socket', + 'port': 'default_port', + }, + } - sections = ['client'] if self.login_path and self.login_path != 'client': sections.append(self.login_path) if self.defaults_suffix: sections.extend([sect + self.defaults_suffix for sect in sections]) - def get(key): - result = None - for sect in cnf: - if sect in sections and key in cnf[sect]: - result = cnf[sect][key] - # HACK: if result is a list, then ConfigObj() probably decoded from - # string by splitting on comma, so reconstruct string by joining on - # comma. - if isinstance(result, list): - result = ','.join(result) - return result + configuration = defaultdict(lambda: None) + for key in keys: + for section in cnf: + if ( + section not in sections or + key not in cnf[section] + ): + continue + new_key = key_transformations.get(section, {}).get(key) or key + configuration[new_key] = strip_matching_quotes( + cnf[section][key]) + + return configuration - return {x: get(x) for x in keys} def merge_ssl_with_cnf(self, ssl, cnf): """Merge SSL configuration dict with cnf dict""" @@ -357,7 +390,7 @@ def merge_ssl_with_cnf(self, ssl, cnf): def connect(self, database='', user='', passwd='', host='', port='', socket='', charset='', local_infile='', ssl='', ssh_user='', ssh_host='', ssh_port='', - ssh_password='', ssh_key_filename=''): + ssh_password='', ssh_key_filename='', init_command='', password_file=''): cnf = {'database': None, 'user': None, @@ -365,6 +398,7 @@ def connect(self, database='', user='', passwd='', host='', port='', 'host': None, 'port': None, 'socket': None, + 'default_socket': None, 'default-character-set': None, 'local-infile': None, 'loose-local-infile': None, @@ -378,18 +412,24 @@ def connect(self, database='', user='', passwd='', host='', port='', cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) # Fall back to config values only if user did not specify a value. - database = database or cnf['database'] - if port or host: - socket = '' - else: - socket = socket or cnf['socket'] user = user or cnf['user'] or os.getenv('USER') host = host or cnf['host'] port = port or cnf['port'] ssl = ssl or {} - passwd = passwd or cnf['password'] + port = port and int(port) + if not port: + port = 3306 + if not host or host == 'localhost': + socket = ( + cnf['socket'] or + cnf['default_socket'] or + guess_socket_location() + ) + + + passwd = passwd if isinstance(passwd, str) else cnf['password'] charset = charset or cnf['default-character-set'] or 'utf8' # Favor whichever local_infile option is set. @@ -406,6 +446,10 @@ def connect(self, database='', user='', passwd='', host='', port='', if not any(v for v in ssl.values()): ssl = None + # if the passwd is not specfied try to set it using the password_file option + password_from_file = self.get_password_from_file(password_file) + passwd = passwd or password_from_file + # Connect to the database. def _connect(): @@ -413,26 +457,29 @@ def _connect(): self.sqlexecute = SQLExecute( database, user, passwd, host, port, socket, charset, local_infile, ssl, ssh_user, ssh_host, ssh_port, - ssh_password, ssh_key_filename + ssh_password, ssh_key_filename, init_command ) except OperationalError as e: - if ('Access denied for user' in e.args[1]): - new_passwd = click.prompt('Password', hide_input=True, - show_default=False, type=str, err=True) + if e.args[0] == ERROR_CODE_ACCESS_DENIED: + if password_from_file: + new_passwd = password_from_file + else: + new_passwd = click.prompt('Password', hide_input=True, + show_default=False, type=str, err=True) self.sqlexecute = SQLExecute( database, user, new_passwd, host, port, socket, charset, local_infile, ssl, ssh_user, ssh_host, - ssh_port, ssh_password, ssh_key_filename + ssh_port, ssh_password, ssh_key_filename, init_command ) else: raise e try: - if (socket is host is port is None) and not WIN: - # Try a sensible default socket first (simplifies auth) - # If we get a connection error, try tcp/ip localhost + if not WIN and socket: + socket_owner = getpwuid(os.stat(socket).st_uid).pw_name + self.echo( + f"Connecting to socket {socket}, owned by user {socket_owner}", err=True) try: - socket = '/var/run/mysqld/mysqld.sock' _connect() except OperationalError as e: # These are "Can't open socket" and 2x "Can't connect" @@ -441,9 +488,11 @@ def _connect(): self.logger.error( "traceback: %r", traceback.format_exc()) self.logger.debug('Retrying over TCP/IP') + self.echo( + "Failed to connect to local MySQL server through socket '{}':".format(socket)) self.echo(str(e), err=True) self.echo( - 'Failed to connect by socket, retrying over TCP/IP', err=True) + 'Retrying over TCP/IP', err=True) # Else fall back to TCP/IP localhost socket = "" @@ -471,8 +520,19 @@ def _connect(): self.echo(str(e), err=True, fg='red') exit(1) + def get_password_from_file(self, password_file): + password_from_file = None + if password_file: + if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \ + and os.access(password_file, os.R_OK): + with open(password_file) as fp: + password_from_file = fp.readline() + password_from_file = password_from_file.rstrip().lstrip() + + return password_from_file + def handle_editor_command(self, text): - """Editor command is any query that is prefixed or suffixed by a '\e'. + r"""Editor command is any query that is prefixed or suffixed by a '\e'. The reason for a while loop is because a user might edit a query multiple times. For eg: @@ -502,6 +562,24 @@ def handle_editor_command(self, text): continue return text + def handle_clip_command(self, text): + r"""A clip command is any query that is prefixed or suffixed by a + '\clip'. + + :param text: Document + :return: Boolean + + """ + + if special.clip_command(text): + query = (special.get_clip_query(text) or + self.get_last_query()) + message = special.copy_query_to_clipboard(sql=query) + if message: + raise RuntimeError(message) + return True + return False + def run_cli(self): iterations = 0 sqlexecute = self.sqlexecute @@ -511,9 +589,6 @@ def run_cli(self): if self.smart_completion: self.refresh_completions() - author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS') - sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS') - history_file = os.path.expanduser( os.environ.get('MYCLI_HISTFILE', '~/.mycli-history')) if dir_path_exists(history_file): @@ -528,21 +603,28 @@ def run_cli(self): key_bindings = mycli_bindings(self) if not self.less_chatty: - print(' '.join(sqlexecute.server_type())) + print(sqlexecute.server_info) print('mycli', __version__) - print('Chat: https://gitter.im/dbcli/mycli') - print('Mail: https://groups.google.com/forum/#!forum/mycli-users') - print('Home: http://mycli.net') - print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file])) + print(SUPPORT_INFO) + print('Thanks to the contributor -', thanks_picker()) def get_message(): prompt = self.get_prompt(self.prompt_format) if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: prompt = self.get_prompt('\\d> ') - return [('class:prompt', prompt)] - - def get_continuation(width, line_number, is_soft_wrap): - continuation = ' ' * (width - 1) + ' ' + prompt = prompt.replace("\\x1b", "\x1b") + return ANSI(prompt) + + def get_continuation(width, *_): + if self.multiline_continuation_char == '': + continuation = '' + elif self.multiline_continuation_char: + left_padding = width - len(self.multiline_continuation_char) + continuation = " " * \ + max((left_padding - 1), 0) + \ + self.multiline_continuation_char + " " + else: + continuation = " " return [('class:continuation', continuation)] def show_suggestion_tip(): @@ -565,6 +647,15 @@ def one_iteration(text=None): self.echo(str(e), err=True, fg='red') return + try: + if self.handle_clip_command(text): + return + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg='red') + return + if not text.strip(): return @@ -577,6 +668,8 @@ def one_iteration(text=None): else: self.echo('Wise choice!') return + else: + destroy = True # Keep track of whether or not the query is mutating. In case # of a multi-statement query, the overall query is considered @@ -635,8 +728,9 @@ def one_iteration(text=None): start = time() result_count += 1 - mutating = mutating or is_mutating(status) + mutating = mutating or destroy or is_mutating(status) special.unset_once_if_written() + special.unset_pipe_once_if_written() except EOFError as e: raise e except KeyboardInterrupt: @@ -750,7 +844,7 @@ def one_iteration(text=None): def log_output(self, output): """Log the output in the audit log, if it's enabled.""" if self.logfile: - click.echo(utf8tounicode(output), file=self.logfile) + click.echo(output, file=self.logfile) def echo(self, s, **kwargs): """Print a message to stdout. @@ -797,6 +891,7 @@ def output(self, output, status=None): self.log_output(line) special.write_tee(line) special.write_once(line) + special.write_pipe_once(line) if fits or output_via_pager: # buffering @@ -809,8 +904,8 @@ def output(self, output, status=None): if not output_via_pager: # doesn't fit, flush buffer - for line in buf: - click.secho(line) + for buf_line in buf: + click.secho(buf_line) buf = [] else: click.secho(line) @@ -880,7 +975,7 @@ def get_prompt(self, string): string = string.replace('\\u', sqlexecute.user or '(none)') string = string.replace('\\h', host or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)') - string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli') + string = string.replace('\\t', sqlexecute.server_info.species.name) string = string.replace('\\n', "\n") string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\m', now.strftime('%M')) @@ -889,6 +984,7 @@ def get_prompt(self, string): string = string.replace('\\r', now.strftime('%I')) string = string.replace('\\s', now.strftime('%S')) string = string.replace('\\p', str(sqlexecute.port)) + string = string.replace('\\A', self.dsn_alias or '(none)') string = string.replace('\\_', ' ') return string @@ -924,8 +1020,8 @@ def format_output(self, title, cur, headers, expanded=False, column_types = None if hasattr(cur, 'description'): def get_col_type(col): - col_type = FIELD_TYPES.get(col[1], text_type) - return col_type if type(col_type) is type else text_type + col_type = FIELD_TYPES.get(col[1], str) + return col_type if type(col_type) is type else str column_types = [get_col_type(col) for col in cur.description] if max_width is not None: @@ -936,19 +1032,19 @@ def get_col_type(col): column_types=column_types, **output_kwargs) - if isinstance(formatted, (text_type)): + if isinstance(formatted, str): formatted = formatted.splitlines() formatted = iter(formatted) - first_line = strip_ansi(next(formatted)) - formatted = itertools.chain([first_line], formatted) - - if (not expanded and max_width and headers and cur and - len(first_line) > max_width): - formatted = self.formatter.format_output( - cur, headers, format_name='vertical', column_types=column_types, **output_kwargs) - if isinstance(formatted, (text_type)): - formatted = iter(formatted.splitlines()) + if (not expanded and max_width and headers and cur): + first_line = next(formatted) + if len(strip_ansi(first_line)) > max_width: + formatted = self.formatter.format_output( + cur, headers, format_name='vertical', column_types=column_types, **output_kwargs) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + else: + formatted = itertools.chain([first_line], formatted) output = itertools.chain(output, formatted) @@ -959,7 +1055,7 @@ def get_reserved_space(self): """Get the number of lines to reserve for the completion menu.""" reserved_space_ratio = .45 max_reserved_space = 8 - _, height = click.get_terminal_size() + _, height = shutil.get_terminal_size() return min(int(round(height * reserved_space_ratio)), max_reserved_space) def get_last_query(self): @@ -982,6 +1078,9 @@ def get_last_query(self): @click.option('--ssh-port', default=22, help='Port to connect to ssh server.') @click.option('--ssh-password', help='Password to connect to ssh server.') @click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.') +@click.option('--ssh-config-path', help='Path to ssh configuration.', + default=os.path.expanduser('~') + '/.ssh/config') +@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.') @click.option('--ssl-ca', help='CA file in PEM format.', type=click.Path(exists=True)) @click.option('--ssl-capath', help='CA directory.') @@ -1003,6 +1102,8 @@ def get_last_query(self): help='Use DSN configured into the [alias_dsn] section of myclirc file.') @click.option('--list-dsn', 'list_dsn', is_flag=True, help='list of DSN configured into the [alias_dsn] section of myclirc file.') +@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True, + help='list ssh configurations in the ssh config (requires paramiko).') @click.option('-R', '--prompt', 'prompt', help='Prompt format (Default: "{0}").'.format( MyCli.default_prompt)) @@ -1024,10 +1125,16 @@ def get_last_query(self): help='Warn before running a destructive query.') @click.option('--local-infile', type=bool, help='Enable/disable LOAD DATA LOCAL INFILE.') -@click.option('--login-path', type=str, +@click.option('-g', '--login-path', type=str, help='Read this path from the login file.') @click.option('-e', '--execute', type=str, help='Execute command and quit.') +@click.option('--init-command', type=str, + help='SQL statement to execute after connecting.') +@click.option('--charset', type=str, + help='Character set for MySQL session.') +@click.option('--password-file', type=click.Path(), + help='File or FIFO path containing the password to connect to the db if not specified otherwise.') @click.argument('database', default='', nargs=1) def cli(database, user, host, port, socket, password, dbname, version, verbose, prompt, logfile, defaults_group_suffix, @@ -1035,7 +1142,8 @@ def cli(database, user, host, port, socket, password, dbname, ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher, ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn, list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, - ssh_key_filename): + ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host, + init_command, charset, password_file): """A MySQL terminal client with auto-completion and syntax highlighting. \b @@ -1072,6 +1180,16 @@ def cli(database, user, host, port, socket, password, dbname, else: click.secho(alias) sys.exit(0) + if list_ssh_config: + ssh_config = read_ssh_config(ssh_config_path) + for host in ssh_config.get_hostnames(): + if verbose: + host_config = ssh_config.lookup(host) + click.secho("{} : {}".format( + host, host_config.get('hostname'))) + else: + click.secho(host) + sys.exit(0) # Choose which ever one has a valid value. database = dbname or database @@ -1089,22 +1207,25 @@ def cli(database, user, host, port, socket, password, dbname, dsn_uri = None - if database and '://' not in database and not any([user, password, host, port]): - dsn = database - database = '' + # Treat the database argument as a DSN alias if we're missing + # other connection information. + if (mycli.config['alias_dsn'] and database and '://' not in database + and not any([user, password, host, port, login_path])): + dsn, database = database, '' if database and '://' in database: - dsn_uri = database - database = '' + dsn_uri, database = database, '' - if dsn is not '': + if dsn: try: dsn_uri = mycli.config['alias_dsn'][dsn] - except KeyError as err: - click.secho('Invalid DSNs found in the config file. ' - 'Please check the "[alias_dsn]" section in myclirc.', - err=True, fg='red') + except KeyError: + click.secho('Could not find the specified DSN in the config file. ' + 'Please check the "[alias_dsn]" section in your ' + 'myclirc.', err=True, fg='red') exit(1) + else: + mycli.dsn_alias = dsn if dsn_uri: uri = urlparse(dsn_uri) @@ -1119,13 +1240,17 @@ def cli(database, user, host, port, socket, password, dbname, if not port: port = uri.port - if not paramiko and ssh_host: - click.secho( - "Cannot use SSH transport because paramiko isn't installed, " - "please install paramiko or don't use --ssh-host=", - err=True, fg="red" - ) - exit(1) + if ssh_config_host: + ssh_config = read_ssh_config( + ssh_config_path + ).lookup(ssh_config_host) + ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') + ssh_user = ssh_user if ssh_user else ssh_config.get('user') + if ssh_config.get('port') and ssh_port == 22: + # port has a default value, overwrite it if it's in the config + ssh_port = int(ssh_config.get('port')) + ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( + 'identityfile', [None])[0] ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) @@ -1142,7 +1267,10 @@ def cli(database, user, host, port, socket, password, dbname, ssh_host=ssh_host, ssh_port=ssh_port, ssh_password=ssh_password, - ssh_key_filename=ssh_key_filename + ssh_key_filename=ssh_key_filename, + init_command=init_command, + charset=charset, + password_file=password_file ) mycli.logger.debug('Launch Params: \n' @@ -1177,14 +1305,15 @@ def cli(database, user, host, port, socket, password, dbname, click.secho('Sorry... :(', err=True, fg='red') exit(1) - try: - sys.stdin = open('/dev/tty') - except (IOError, OSError): - mycli.logger.warning('Unable to open TTY as stdin.') + if mycli.destructive_warning and is_destructive(stdin_text): + try: + sys.stdin = open('/dev/tty') + warn_confirmed = confirm_destructive_query(stdin_text) + except (IOError, OSError): + mycli.logger.warning('Unable to open TTY as stdin.') + if not warn_confirmed: + exit(0) - if (mycli.destructive_warning and - confirm_destructive_query(stdin_text) is False): - exit(0) try: new_line = True @@ -1207,35 +1336,12 @@ def need_completion_refresh(queries): try: first_token = query.split()[0] if first_token.lower() in ('alter', 'create', 'use', '\\r', - '\\u', 'connect', 'drop'): + '\\u', 'connect', 'drop', 'rename'): return True except Exception: return False -def is_dropping_database(queries, dbname): - """Determine if the query is dropping a specific database.""" - if dbname is None: - return False - - def normalize_db_name(db): - return db.lower().strip('`"') - - dbname = normalize_db_name(dbname) - - for query in sqlparse.parse(queries): - if query.get_name() is None: - continue - - first_token = query.token_first(skip_cm=True) - _, second_token = query.token_next(0, skip_cm=True) - database_name = normalize_db_name(query.get_name()) - if (first_token.value.lower() == 'drop' and - second_token.value.lower() in ('database', 'schema') and - database_name == dbname): - return True - - def need_completion_reset(queries): """Determines if the statement is a database switch such as 'use' or '\\u'. When a database is changed the existing completions must be reset before we @@ -1256,9 +1362,10 @@ def is_mutating(status): return False mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop', - 'replace', 'truncate', 'load']) + 'replace', 'truncate', 'load', 'rename']) return status.split(None, 1)[0].lower() in mutating + def is_select(status): """Returns true if the first word in status is 'select'.""" if not status: @@ -1266,14 +1373,49 @@ def is_select(status): return status.split(None, 1)[0].lower() == 'select' -def thanks_picker(files=()): +def thanks_picker(): + import mycli + lines = ( + resources.read_text(mycli, 'AUTHORS') + + resources.read_text(mycli, 'SPONSORS') + ).split('\n') + contents = [] - for line in fileinput.input(files=files): - m = re.match('^ *\* (.*)', line) + for line in lines: + m = re.match(r'^ *\* (.*)', line) if m: contents.append(m.group(1)) return choice(contents) +@prompt_register('edit-and-execute-command') +def edit_and_execute(event): + """Different from the prompt-toolkit default, we want to have a choice not + to execute a query after editing, hence validate_and_handle=False.""" + buff = event.current_buffer + buff.open_in_editor(validate_and_handle=False) + + +def read_ssh_config(ssh_config_path): + ssh_config = paramiko.config.SSHConfig() + try: + with open(ssh_config_path) as f: + ssh_config.parse(f) + except FileNotFoundError as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + # Paramiko prior to version 2.7 raises Exception on parse errors. + # In 2.7 it has become paramiko.ssh_exception.SSHException, + # but let's catch everything for compatibility + except Exception as err: + click.secho( + f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ', + err=True, fg='red' + ) + sys.exit(1) + else: + return ssh_config + + if __name__ == "__main__": cli() diff --git a/mycli/myclirc b/mycli/myclirc index 571ffb08..c89caa05 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -41,6 +41,7 @@ table_format = ascii # friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, # fruity. # Screenshots at http://mycli.net/syntax +# Can be further modified in [colors] syntax_style = default # Keybindings: Possible values: emacs, vi. @@ -59,13 +60,15 @@ wider_completion_menu = False # \n - Newline # \P - AM/PM # \p - Port -# \R - The current time, in 24-hour military time (0–23) -# \r - The current time, standard 12-hour time (1–12) +# \R - The current time, in 24-hour military time (0-23) +# \r - The current time, standard 12-hour time (1-12) # \s - Seconds of the current time # \t - Product type (Percona, MySQL, MariaDB) +# \A - DSN alias name (from the [alias_dsn] section) # \u - Username +# \x1b[...m - insert ANSI escape sequence prompt = '\t \u@\h:\d> ' -prompt_continuation = '-> ' +prompt_continuation = '->' # Skip intro info on startup and outro info on exit less_chatty = False @@ -110,6 +113,36 @@ bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' output.header = "#00ff5f bold" output.odd-row = "" output.even-row = "" +output.null = "#808080" + +# SQL syntax highlighting overrides +# sql.comment = 'italic #408080' +# sql.comment.multi-line = '' +# sql.comment.single-line = '' +# sql.comment.optimizer-hint = '' +# sql.escape = 'border:#FF0000' +# sql.keyword = 'bold #008000' +# sql.datatype = 'nobold #B00040' +# sql.literal = '' +# sql.literal.date = '' +# sql.symbol = '' +# sql.quoted-schema-object = '' +# sql.quoted-schema-object.escape = '' +# sql.constant = '#880000' +# sql.function = '#0000FF' +# sql.variable = '#19177C' +# sql.number = '#666666' +# sql.number.binary = '' +# sql.number.float = '' +# sql.number.hex = '' +# sql.number.integer = '' +# sql.operator = '#666666' +# sql.punctuation = '' +# sql.string = '#BA2121' +# sql.string.double-quouted = '' +# sql.string.escape = 'bold #BB6622' +# sql.string.single-quoted = '' +# sql.whitespace = '' # Favorite queries. [favorite_queries] diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index bea79274..c7db06cb 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -1,20 +1,8 @@ -from __future__ import print_function -import os -import sys import sqlparse from sqlparse.sql import Comparison, Identifier, Where -from sqlparse.compat import text_type from .parseutils import last_word, extract_tables, find_prev_keyword from .special import parse_special_command -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 - -if PY3: - string_types = str -else: - string_types = basestring - def suggest_type(full_text, text_before_cursor): """Takes the full_text that is typed so far and also the text before the @@ -64,7 +52,7 @@ def suggest_type(full_text, text_before_cursor): stmt_start, stmt_end = 0, 0 for statement in parsed: - stmt_len = len(text_type(statement)) + stmt_len = len(str(statement)) stmt_start, stmt_end = stmt_end, stmt_end + stmt_len if stmt_end >= current_pos: @@ -123,7 +111,7 @@ def suggest_special(text): def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): - if isinstance(token, string_types): + if isinstance(token, str): token_v = token.lower() elif isinstance(token, Comparison): # If 'token' is a Comparison type such as diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py index 5ebdcd97..79fe26dc 100644 --- a/mycli/packages/filepaths.py +++ b/mycli/packages/filepaths.py @@ -1,13 +1,20 @@ -# -*- coding: utf-8 -from __future__ import unicode_literals -from mycli.encodingutils import text_type import os +import platform + + +if os.name == "posix": + if platform.system() == "Darwin": + DEFAULT_SOCKET_DIRS = ("/tmp",) + else: + DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib") +else: + DEFAULT_SOCKET_DIRS = () def list_path(root_dir): """List directory if exists. - :param dir: str + :param root_dir: str :return: list """ @@ -62,10 +69,10 @@ def suggest_path(root_dir): """ if not root_dir: - return [text_type(os.path.abspath(os.sep)), text_type('~'), text_type(os.curdir), text_type(os.pardir)] + return [os.path.abspath(os.sep), '~', os.curdir, os.pardir] if '~' in root_dir: - root_dir = text_type(os.path.expanduser(root_dir)) + root_dir = os.path.expanduser(root_dir) if not os.path.exists(root_dir): root_dir, _ = os.path.split(root_dir) @@ -84,3 +91,16 @@ def dir_path_exists(path): """ return os.path.exists(os.path.dirname(path)) + + +def guess_socket_location(): + """Try to guess the location of the default mysql socket file.""" + socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS) + for directory in socket_dirs: + for r, dirs, files in os.walk(directory, topdown=True): + for filename in files: + name, ext = os.path.splitext(filename) + if name.startswith("mysql") and ext in ('.socket', '.sock'): + return os.path.join(r, filename) + dirs[:] = [d for d in dirs if d.startswith("mysql")] + return None diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py new file mode 100644 index 00000000..045b00ea --- /dev/null +++ b/mycli/packages/paramiko_stub/__init__.py @@ -0,0 +1,28 @@ +"""A module to import instead of paramiko when it is not available (to avoid +checking for paramiko all over the place). + +When paramiko is first envoked, it simply shuts down mycli, telling +user they either have to install paramiko or should not use SSH +features. + +""" + + +class Paramiko: + def __getattr__(self, name): + import sys + from textwrap import dedent + print(dedent(""" + To enable certain SSH features you need to install paramiko: + + pip install paramiko + + It is required for the following configuration options: + --list-ssh-config + --ssh-config-host + --ssh-host + """)) + sys.exit(1) + + +paramiko = Paramiko() diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index 3e0f2e70..d47f59a5 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -1,4 +1,3 @@ -from __future__ import print_function import re import sqlparse from sqlparse.sql import IdentifierList, Identifier, Function @@ -12,11 +11,12 @@ # This matches everything except spaces, parens, colon, comma, and period 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), # This matches everything except a space. - 'all_punctuations': re.compile('([^\s]+)$'), - } + 'all_punctuations': re.compile(r'([^\s]+)$'), +} + def last_word(text, include='alphanum_underscore'): - """ + r""" Find the last word in a sentence. >>> last_word('abc') @@ -81,6 +81,13 @@ def extract_from_part(parsed, stop_at_punctuation=True): yield x elif stop_at_punctuation and item.ttype is Punctuation: return + # Multiple JOINs in the same query won't work properly since + # "ON" is a keyword and will trigger the next elif condition. + # So instead of stooping the loop when finding an "ON" skip it + # eg: 'SELECT * FROM abc JOIN def ON abc.id = def.abc_id JOIN ghi' + elif item.ttype is Keyword and item.value.upper() == 'ON': + tbl_prefix_seen = False + continue # An incomplete nested select won't be recognized correctly as a # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes # the second FROM to trigger this elif condition resulting in a @@ -203,20 +210,57 @@ def queries_start_with(queries, prefixes): return False +def query_has_where_clause(query): + """Check if the query contains a where-clause.""" + return any( + isinstance(token, sqlparse.sql.Where) + for token_list in sqlparse.parse(query) + for token in token_list + ) + + def is_destructive(queries): """Returns if any of the queries in *queries* is destructive.""" keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter') - return queries_start_with(queries, keywords) - - -def is_open_quote(sql): - """Returns true if the query contains an unclosed quote.""" + for query in sqlparse.split(queries): + if query: + if query_starts_with(query, keywords) is True: + return True + elif query_starts_with( + query, ['update'] + ) is True and not query_has_where_clause(query): + return True - # parsed can contain one or more semi-colon separated commands - parsed = sqlparse.parse(sql) - return any(_parsed_is_open_quote(p) for p in parsed) + return False if __name__ == '__main__': sql = 'select * from (select t. from tabl t' print (extract_tables(sql)) + + +def is_dropping_database(queries, dbname): + """Determine if the query is dropping a specific database.""" + result = False + if dbname is None: + return False + + def normalize_db_name(db): + return db.lower().strip('`"') + + dbname = normalize_db_name(dbname) + + for query in sqlparse.parse(queries): + keywords = [t for t in query.tokens if t.is_keyword] + if len(keywords) < 2: + continue + if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in ( + "database", + "schema", + ): + database_token = next( + (t for t in query.tokens if isinstance(t, Identifier)), None + ) + if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: + result = keywords[0].normalized == "DROP" + return result diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py index 138cef38..fb1e431a 100644 --- a/mycli/packages/prompt_utils.py +++ b/mycli/packages/prompt_utils.py @@ -1,12 +1,28 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - - import sys import click from .parseutils import is_destructive +class ConfirmBoolParamType(click.ParamType): + name = 'confirmation' + + def convert(self, value, param, ctx): + if isinstance(value, bool): + return bool(value) + value = value.lower() + if value in ('yes', 'y'): + return True + elif value in ('no', 'n'): + return False + self.fail('%s is not a valid boolean' % value, param, ctx) + + def __repr__(self): + return 'BOOL' + + +BOOLEAN_TYPE = ConfirmBoolParamType() + + def confirm_destructive_query(queries): """Check if the query is destructive and prompts the user to confirm. @@ -19,7 +35,7 @@ def confirm_destructive_query(queries): prompt_text = ("You're about to run a destructive command.\n" "Do you want to proceed? (y/n)") if is_destructive(queries) and sys.stdin.isatty(): - return prompt(prompt_text, type=bool) + return prompt(prompt_text, type=BOOLEAN_TYPE) def confirm(*args, **kwargs): diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py index d29507a6..45d70690 100644 --- a/mycli/packages/special/dbcommands.py +++ b/mycli/packages/special/dbcommands.py @@ -1,4 +1,3 @@ -from __future__ import unicode_literals import logging import os import platform @@ -136,23 +135,25 @@ def status(cur, **_): else: output.append(('UNIX socket:', variables['socket'])) - output.append(('Uptime:', format_uptime(status['Uptime']))) - - # Print the current server statistics. - stats = [] - stats.append('Connections: {0}'.format(status['Threads_connected'])) - if 'Queries' in status: - stats.append('Queries: {0}'.format(status['Queries'])) - stats.append('Slow queries: {0}'.format(status['Slow_queries'])) - stats.append('Opens: {0}'.format(status['Opened_tables'])) - stats.append('Flush tables: {0}'.format(status['Flush_commands'])) - stats.append('Open tables: {0}'.format(status['Open_tables'])) - if 'Queries' in status: - queries_per_second = int(status['Queries']) / int(status['Uptime']) - stats.append('Queries per second avg: {:.3f}'.format( - queries_per_second)) - stats = ' '.join(stats) - footer.append('\n' + stats) + if 'Uptime' in status: + output.append(('Uptime:', format_uptime(status['Uptime']))) + + if 'Threads_connected' in status: + # Print the current server statistics. + stats = [] + stats.append('Connections: {0}'.format(status['Threads_connected'])) + if 'Queries' in status: + stats.append('Queries: {0}'.format(status['Queries'])) + stats.append('Slow queries: {0}'.format(status['Slow_queries'])) + stats.append('Opens: {0}'.format(status['Opened_tables'])) + stats.append('Flush tables: {0}'.format(status['Flush_commands'])) + stats.append('Open tables: {0}'.format(status['Open_tables'])) + if 'Queries' in status: + queries_per_second = int(status['Queries']) / int(status['Uptime']) + stats.append('Queries per second avg: {:.3f}'.format( + queries_per_second)) + stats = ' '.join(stats) + footer.append('\n' + stats) footer.append('--------------') return [('\n'.join(title), output, '', '\n'.join(footer))] diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py new file mode 100644 index 00000000..994b134b --- /dev/null +++ b/mycli/packages/special/delimitercommand.py @@ -0,0 +1,80 @@ +import re +import sqlparse + + +class DelimiterCommand(object): + def __init__(self): + self._delimiter = ';' + + def _split(self, sql): + """Temporary workaround until sqlparse.split() learns about custom + delimiters.""" + + placeholder = "\ufffc" # unicode object replacement character + + if self._delimiter == ';': + return sqlparse.split(sql) + + # We must find a string that original sql does not contain. + # Most likely, our placeholder is enough, but if not, keep looking + while placeholder in sql: + placeholder += placeholder[0] + sql = sql.replace(';', placeholder) + sql = sql.replace(self._delimiter, ';') + + split = sqlparse.split(sql) + + return [ + stmt.replace(';', self._delimiter).replace(placeholder, ';') + for stmt in split + ] + + def queries_iter(self, input): + """Iterate over queries in the input string.""" + + queries = self._split(input) + while queries: + for sql in queries: + delimiter = self._delimiter + sql = queries.pop(0) + if sql.endswith(delimiter): + trailing_delimiter = True + sql = sql.strip(delimiter) + else: + trailing_delimiter = False + + yield sql + + # if the delimiter was changed by the last command, + # re-split everything, and if we previously stripped + # the delimiter, append it to the end + if self._delimiter != delimiter: + combined_statement = ' '.join([sql] + queries) + if trailing_delimiter: + combined_statement += delimiter + queries = self._split(combined_statement)[1:] + + def set(self, arg, **_): + """Change delimiter. + + Since `arg` is everything that follows the DELIMITER token + after sqlparse (it may include other statements separated by + the new delimiter), we want to set the delimiter to the first + word of it. + + """ + match = arg and re.search(r'[^\s]+', arg) + if not match: + message = 'Missing required argument, delimiter' + return [(None, None, None, message)] + + delimiter = match.group() + if delimiter.lower() == 'delimiter': + return [(None, None, None, 'Invalid delimiter "delimiter"')] + + self._delimiter = delimiter + return [(None, None, None, "Changed delimiter to {}".format(delimiter))] + + @property + def current(self): + return self._delimiter diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py index ed47127f..0b91400e 100644 --- a/mycli/packages/special/favoritequeries.py +++ b/mycli/packages/special/favoritequeries.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - class FavoriteQueries(object): section_name = 'favorite_queries' @@ -26,7 +23,7 @@ class FavoriteQueries(object): ╒════════╤════════╕ │ a │ b │ ╞════════╪════════╡ - │ 日本語 │ 日本語 │ + │ 日本語 │ 日本語 │ ╘════════╧════════╛ # Delete a favorite query. @@ -34,9 +31,16 @@ class FavoriteQueries(object): simple: Deleted ''' + # Class-level variable, for convenience to use as a singleton. + instance = None + def __init__(self, config): self.config = config + @classmethod + def from_config(cls, config): + return FavoriteQueries(config) + def list(self): return self.config.get(self.section_name, []) diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 166e457c..01f3c7ba 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -1,4 +1,3 @@ -from __future__ import unicode_literals import os import re import locale @@ -9,12 +8,13 @@ from time import sleep import click +import pyperclip import sqlparse -from configobj import ConfigObj from . import export from .main import special_command, NO_QUERY, PARSED_QUERY from .favoritequeries import FavoriteQueries +from .delimitercommand import DelimiterCommand from .utils import handle_cd_command from mycli.packages.prompt_utils import confirm_destructive_query @@ -22,8 +22,12 @@ use_expanded_output = False PAGER_ENABLED = True tee_file = None -once_file = written_to_once_file = None -favoritequeries = FavoriteQueries(ConfigObj()) +once_file = None +written_to_once_file = False +pipe_once_process = None +written_to_pipe_once_process = False +delimiter_command = DelimiterCommand() + @export def set_timing_enabled(val): @@ -36,11 +40,6 @@ def set_pager_enabled(val): PAGER_ENABLED = val -@export -def set_favorite_queries(config): - global favoritequeries - favoritequeries = FavoriteQueries(config) - @export def is_pager_enabled(): return PAGER_ENABLED @@ -119,7 +118,7 @@ def get_editor_query(sql): # The reason we can't simply do .strip('\e') is that it strips characters, # not a substring. So it'll strip "e" in the end of the sql also! # Ex: "select * from style\e" -> "select * from styl". - pattern = re.compile('(^\\\e|\\\e$)') + pattern = re.compile(r'(^\\e|\\e$)') while pattern.search(sql): sql = pattern.sub('', sql) @@ -148,7 +147,7 @@ def open_external_editor(filename=None, sql=None): if filename: try: - with open(filename, encoding='utf-8') as f: + with open(filename) as f: query = f.read() except IOError: message = 'Error reading file: %s.' % filename @@ -163,6 +162,47 @@ def open_external_editor(filename=None, sql=None): return (query, message) +@export +def clip_command(command): + """Is this a clip command? + + :param command: string + + """ + # It is possible to have `\clip` or `SELECT * FROM \clip`. So we check + # for both conditions. + return command.strip().endswith('\\clip') or command.strip().startswith('\\clip') + + +@export +def get_clip_query(sql): + """Get the query part of a clip command.""" + sql = sql.strip() + + # The reason we can't simply do .strip('\clip') is that it strips characters, + # not a substring. So it'll strip "c" in the end of the sql also! + pattern = re.compile(r'(^\\clip|\\clip$)') + while pattern.search(sql): + sql = pattern.sub('', sql) + + return sql + + +@export +def copy_query_to_clipboard(sql=None): + """Send query to the clipboard.""" + + sql = sql or '' + message = None + + try: + pyperclip.copy(u'{sql}'.format(sql=sql)) + except RuntimeError as e: + message = 'Error clipping query: %s.' % e.strerror + + return message + + @special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True) def execute_favorite_query(cur, arg, **_): """Returns (title, rows, headers, status)""" @@ -174,7 +214,7 @@ def execute_favorite_query(cur, arg, **_): name, _, arg_str = arg.partition(' ') args = shlex.split(arg_str) - query = favoritequeries.get(name) + query = FavoriteQueries.instance.get(name) if query is None: message = "No favorite query: %s" % (name) yield (None, None, None, message) @@ -198,10 +238,11 @@ def list_favorite_queries(): Returns (title, rows, headers, status)""" headers = ["Name", "Query"] - rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()] + rows = [(r, FavoriteQueries.instance.get(r)) + for r in FavoriteQueries.instance.list()] if not rows: - status = '\nNo favorite queries found.' + favoritequeries.usage + status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage else: status = '' return [('', rows, headers, status)] @@ -216,7 +257,7 @@ def subst_favorite_query_args(query, args): query = query.replace(subst_var, val) - match = re.search('\\$\d+', query) + match = re.search(r'\$\d+', query) if match: return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query] @@ -227,7 +268,7 @@ def save_favorite_query(arg, **_): """Save a new favorite query. Returns (title, rows, headers, status)""" - usage = 'Syntax: \\fs name query.\n\n' + favoritequeries.usage + usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage if not arg: return [(None, None, None, usage)] @@ -238,18 +279,18 @@ def save_favorite_query(arg, **_): return [(None, None, None, usage + 'Err: Both name and query are required.')] - favoritequeries.save(name, query) + FavoriteQueries.instance.save(name, query) return [(None, None, None, "Saved.")] + @special_command('\\fd', '\\fd [name]', 'Delete a favorite query.') def delete_favorite_query(arg, **_): - """Delete an existing favorite query. - """ - usage = 'Syntax: \\fd name.\n\n' + favoritequeries.usage + """Delete an existing favorite query.""" + usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage if not arg: return [(None, None, None, usage)] - status = favoritequeries.delete(arg) + status = FavoriteQueries.instance.delete(arg) return [(None, None, None, status)] @@ -261,7 +302,7 @@ def execute_system_command(arg, **_): usage = "Syntax: system [command].\n" if not arg: - return [(None, None, None, usage)] + return [(None, None, None, usage)] try: command = arg.strip() @@ -338,9 +379,14 @@ def write_tee(output): 'Append next result to an output file (overwrite using -o).', aliases=('\\o', )) def set_once(arg, **_): - global once_file + global once_file, written_to_once_file - once_file = parseargfile(arg) + try: + once_file = open(**parseargfile(arg)) + except (IOError, OSError) as e: + raise OSError("Cannot write to file '{}': {}".format( + e.filename, e.strerror)) + written_to_once_file = False return [(None, None, None, "")] @@ -349,27 +395,68 @@ def set_once(arg, **_): def write_once(output): global once_file, written_to_once_file if output and once_file: - try: - f = open(**once_file) - except (IOError, OSError) as e: - once_file = None - raise OSError("Cannot write to file '{}': {}".format( - e.filename, e.strerror)) - - with f: - click.echo(output, file=f, nl=False) - click.echo(u"\n", file=f, nl=False) + click.echo(output, file=once_file, nl=False) + click.echo(u"\n", file=once_file, nl=False) + once_file.flush() written_to_once_file = True @export def unset_once_if_written(): """Unset the once file, if it has been written to.""" - global once_file - if written_to_once_file: + global once_file, written_to_once_file + if written_to_once_file and once_file: + once_file.close() once_file = None +@special_command('\\pipe_once', '\\| command', + 'Send next result to a subprocess.', + aliases=('\\|', )) +def set_pipe_once(arg, **_): + global pipe_once_process, written_to_pipe_once_process + pipe_once_cmd = shlex.split(arg) + if len(pipe_once_cmd) == 0: + raise OSError("pipe_once requires a command") + written_to_pipe_once_process = False + pipe_once_process = subprocess.Popen(pipe_once_cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=1, + encoding='UTF-8', + universal_newlines=True) + return [(None, None, None, "")] + + +@export +def write_pipe_once(output): + global pipe_once_process, written_to_pipe_once_process + if output and pipe_once_process: + try: + click.echo(output, file=pipe_once_process.stdin, nl=False) + click.echo(u"\n", file=pipe_once_process.stdin, nl=False) + except (IOError, OSError) as e: + pipe_once_process.terminate() + raise OSError( + "Failed writing to pipe_once subprocess: {}".format(e.strerror)) + written_to_pipe_once_process = True + + +@export +def unset_pipe_once_if_written(): + """Unset the pipe_once cmd, if it has been written to.""" + global pipe_once_process, written_to_pipe_once_process + if written_to_pipe_once_process: + (stdout_data, stderr_data) = pipe_once_process.communicate() + if len(stdout_data) > 0: + print(stdout_data.rstrip(u"\n")) + if len(stderr_data) > 0: + print(stderr_data.rstrip(u"\n")) + pipe_once_process = None + written_to_pipe_once_process = False + + @special_command( 'watch', 'watch [seconds] [-c] query', @@ -437,3 +524,20 @@ def watch_query(arg, **kwargs): return finally: set_pager_enabled(old_pager_enabled) + + +@export +@special_command('delimiter', None, 'Change SQL delimiter.') +def set_delimiter(arg, **_): + return delimiter_command.set(arg) + + +@export +def get_current_delimiter(): + return delimiter_command.current + + +@export +def split_queries(input): + for query in delimiter_command.queries_iter(input): + yield query diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py index f6e7a115..ab04f30d 100644 --- a/mycli/packages/special/main.py +++ b/mycli/packages/special/main.py @@ -1,4 +1,3 @@ -from __future__ import unicode_literals import logging from collections import namedtuple @@ -113,6 +112,8 @@ def quit(*_args): @special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).', arg_type=NO_QUERY, case_sensitive=True) +@special_command('\\clip', '\\clip', 'Copy query to the system clipboard.', + arg_type=NO_QUERY, case_sensitive=True) @special_command('\\G', '\\G', 'Display current query results vertically.', arg_type=NO_QUERY, case_sensitive=True) def stub(): diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py index b5e43466..e6587bd3 100644 --- a/mycli/packages/tabular_output/sql_format.py +++ b/mycli/packages/tabular_output/sql_format.py @@ -1,8 +1,5 @@ -# -*- coding: utf-8 -*- """Format adapter for sql.""" -from __future__ import unicode_literals -from cli_helpers.utils import filter_dict_by_key from mycli.packages.parseutils import extract_tables supported_formats = ('sql-insert', 'sql-update', 'sql-update-1', @@ -11,6 +8,13 @@ preprocessors = () +def escape_for_sql_statement(value): + if isinstance(value, bytes): + return f"X'{value.hex()}'" + else: + return formatter.mycli.sqlexecute.conn.escape(value) + + def adapter(data, headers, table_format=None, **kwargs): tables = extract_tables(formatter.query) if len(tables) > 0: @@ -21,13 +25,13 @@ def adapter(data, headers, table_format=None, **kwargs): table_name = table[1] else: table_name = "`DUAL`" - escape = formatter.mycli.sqlexecute.conn.escape if table_format == 'sql-insert': h = "`, `".join(headers) yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h) prefix = " " for d in data: - values = ", ".join(escape(v) for i, v in enumerate(d)) + values = ", ".join(escape_for_sql_statement(v) + for i, v in enumerate(d)) yield "{}({})".format(prefix, values) if prefix == " ": prefix = ", " @@ -41,11 +45,12 @@ def adapter(data, headers, table_format=None, **kwargs): yield "UPDATE {} SET".format(table_name) prefix = " " for i, v in enumerate(d[keys:], keys): - yield "{}`{}` = {}".format(prefix, headers[i], escape(v)) + yield "{}`{}` = {}".format(prefix, headers[i], escape_for_sql_statement(v)) if prefix == " ": prefix = ", " f = "`{}` = {}" - where = (f.format(headers[i], escape(d[i])) for i in range(keys)) + where = (f.format(headers[i], escape_for_sql_statement( + d[i])) for i in range(keys)) yield "WHERE {};".format(" AND ".join(where)) diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 1e11c9c3..3656aa69 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -1,5 +1,3 @@ -from __future__ import print_function -from __future__ import unicode_literals import logging from re import compile, escape from collections import Counter @@ -9,7 +7,7 @@ from .packages.completion_engine import suggest_type from .packages.parseutils import last_word from .packages.filepaths import parse_path, complete_path, suggest_path -from .packages.special.iocommands import favoritequeries +from .packages.special.favoritequeries import FavoriteQueries _logger = logging.getLogger(__name__) @@ -21,7 +19,7 @@ class SQLCompleter(Completer): 'CHARACTER SET', 'CHECK', 'COLLATE', 'COLUMN', 'COMMENT', 'COMMIT', 'CONSTRAINT', 'CREATE', 'CURRENT', 'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', 'DECIMAL', 'DEFAULT', - 'DELETE FROM', 'DELIMITER', 'DESC', 'DESCRIBE', 'DROP', + 'DELETE FROM', 'DESC', 'DESCRIBE', 'DROP', 'ELSE', 'END', 'ENGINE', 'ESCAPE', 'EXISTS', 'FILE', 'FLOAT', 'FOR', 'FOREIGN KEY', 'FORMAT', 'FROM', 'FULL', 'FUNCTION', 'GRANT', 'GROUP BY', 'HAVING', 'HOST', 'IDENTIFIED', 'IN', @@ -61,7 +59,7 @@ def __init__(self, smart_completion=True, supported_formats=(), keyword_casing=' self.reserved_words = set() for x in self.keywords: self.reserved_words.update(x.split()) - self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$") + self.name_pattern = compile(r"^[_a-z][_a-z0-9\$]*$") self.special_commands = [] self.table_formats = supported_formats @@ -74,7 +72,7 @@ def escape_name(self, name): if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): - name = '`%s`' % name + name = '`%s`' % name return name @@ -357,7 +355,7 @@ def get_completions(self, document, complete_event, smart_completion=None): completions.extend(special) elif suggestion['type'] == 'favoritequery': queries = self.find_matches(word_before_cursor, - favoritequeries.list(), + FavoriteQueries.instance.list(), start_only=False, fuzzy=True) completions.extend(queries) elif suggestion['type'] == 'table_format': diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 61ba6848..94614387 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -1,15 +1,17 @@ +import enum import logging +import re + import pymysql -import sqlparse from .packages import special from pymysql.constants import FIELD_TYPE -from pymysql.converters import (convert_mysql_timestamp, convert_datetime, +from pymysql.converters import (convert_datetime, convert_timedelta, convert_date, conversions, decoders) try: import paramiko -except: - paramiko = False +except ImportError: + from mycli.packages.paramiko_stub import paramiko _logger = logging.getLogger(__name__) @@ -18,17 +20,71 @@ FIELD_TYPE.NULL: type(None) }) + +ERROR_CODE_ACCESS_DENIED = 1045 + + +class ServerSpecies(enum.Enum): + MySQL = 'MySQL' + MariaDB = 'MariaDB' + Percona = 'Percona' + Unknown = 'MySQL' + + +class ServerInfo: + def __init__(self, species, version_str): + self.species = species + self.version_str = version_str + self.version = self.calc_mysql_version_value(version_str) + + @staticmethod + def calc_mysql_version_value(version_str) -> int: + if not version_str or not isinstance(version_str, str): + return 0 + try: + major, minor, patch = version_str.split('.') + except ValueError: + return 0 + else: + return int(major) * 10_000 + int(minor) * 100 + int(patch) + + @classmethod + def from_version_string(cls, version_string): + if not version_string: + return cls(ServerSpecies.Unknown, '') + + re_species = ( + (r'(?P[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), + (r'(?P[0-9\.]+)[a-z0-9]*-(?P[0-9]+$)', + ServerSpecies.Percona), + (r'(?P[0-9\.]+)[a-z0-9]*-(?P[A-Za-z0-9_]+)', + ServerSpecies.MySQL), + ) + for regexp, species in re_species: + match = re.search(regexp, version_string) + if match is not None: + parsed_version = match.group('version') + detected_species = species + break + else: + detected_species = ServerSpecies.Unknown + parsed_version = '' + + return cls(detected_species, parsed_version) + + def __str__(self): + if self.species: + return f'{self.species.value} {self.version_str}' + else: + return self.version_str + + class SQLExecute(object): databases_query = '''SHOW DATABASES''' tables_query = '''SHOW TABLES''' - version_query = '''SELECT @@VERSION''' - - version_comment_query = '''SELECT @@VERSION_COMMENT''' - version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"''' - show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' @@ -42,7 +98,7 @@ class SQLExecute(object): def __init__(self, database, user, password, host, port, socket, charset, local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, - ssh_key_filename): + ssh_key_filename, init_command=None): self.dbname = database self.user = user self.password = password @@ -52,19 +108,20 @@ def __init__(self, database, user, password, host, port, socket, charset, self.charset = charset self.local_infile = local_infile self.ssl = ssl - self._server_type = None + self.server_info = None self.connection_id = None self.ssh_user = ssh_user self.ssh_host = ssh_host self.ssh_port = ssh_port self.ssh_password = ssh_password self.ssh_key_filename = ssh_key_filename + self.init_command = init_command self.connect() def connect(self, database=None, user=None, password=None, host=None, port=None, socket=None, charset=None, local_infile=None, ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, - ssh_password=None, ssh_key_filename=None): + ssh_password=None, ssh_key_filename=None, init_command=None): db = (database or self.dbname) user = (user or self.user) password = (password or self.password) @@ -79,6 +136,7 @@ def connect(self, database=None, user=None, password=None, host=None, ssh_port = (ssh_port or self.ssh_port) ssh_password = (ssh_password or self.ssh_password) ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) + init_command = (init_command or self.init_command) _logger.debug( 'Connection DB Params: \n' '\tdatabase: %r' @@ -93,13 +151,15 @@ def connect(self, database=None, user=None, password=None, host=None, '\tssh_host: %r' '\tssh_port: %r' '\tssh_password: %r' - '\tssh_key_filename: %r', + '\tssh_key_filename: %r' + '\tinit_command: %r', db, user, host, port, socket, charset, local_infile, ssl, - ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename + ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, + init_command ) conv = conversions.copy() conv.update({ - FIELD_TYPE.TIMESTAMP: lambda obj: (convert_mysql_timestamp(obj) or obj), + FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj), FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj), FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj), FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj), @@ -110,15 +170,19 @@ def connect(self, database=None, user=None, password=None, host=None, if ssh_host: defer_connect = True + client_flag = pymysql.constants.CLIENT.INTERACTIVE + if init_command and len(list(special.split_queries(init_command))) > 1: + client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS + conn = pymysql.connect( database=db, user=user, password=password, host=host, port=port, unix_socket=socket, use_unicode=True, charset=charset, - autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE, + autocommit=True, client_flag=client_flag, local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", - defer_connect=defer_connect + defer_connect=defer_connect, init_command=init_command ) - if ssh_host and paramiko: + if ssh_host: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy()) @@ -146,8 +210,10 @@ def connect(self, database=None, user=None, password=None, host=None, self.socket = socket self.charset = charset self.ssl = ssl + self.init_command = init_command # retrieve connection id self.reset_connection_id() + self.server_info = ServerInfo.from_version_string(conn.server_version) def run(self, statement): """Execute the sql in the database and return the results. The results @@ -166,12 +232,9 @@ def run(self, statement): if statement.startswith('\\fs'): components = [statement] else: - components = sqlparse.split(statement) + components = special.split_queries(statement) for sql in components: - # Remove spaces, eol and semi-colons. - sql = sql.rstrip(';') - # \G is treated specially since we have to set the expanded output. if sql.endswith('\\G'): special.set_expanded_output(True) @@ -267,37 +330,6 @@ def users(self): for row in cur: yield row - def server_type(self): - if self._server_type: - return self._server_type - with self.conn.cursor() as cur: - _logger.debug('Version Query. sql: %r', self.version_query) - cur.execute(self.version_query) - version = cur.fetchone()[0] - if version[0] == '4': - _logger.debug('Version Comment. sql: %r', - self.version_comment_query_mysql4) - cur.execute(self.version_comment_query_mysql4) - version_comment = cur.fetchone()[1].lower() - if isinstance(version_comment, bytes): - # with python3 this query returns bytes - version_comment = version_comment.decode('utf-8') - else: - _logger.debug('Version Comment. sql: %r', - self.version_comment_query) - cur.execute(self.version_comment_query) - version_comment = cur.fetchone()[0].lower() - - if 'mariadb' in version_comment: - product_type = 'mariadb' - elif 'percona' in version_comment: - product_type = 'percona' - else: - product_type = 'mysql' - - self._server_type = (product_type, version) - return self._server_type - def get_connection_id(self): if not self.connection_id: self.reset_connection_id() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..5422131c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --ignore=mycli/packages/paramiko_stub/__init__.py diff --git a/release.py b/release.py index 18e1b8f2..39df8a3a 100755 --- a/release.py +++ b/release.py @@ -1,8 +1,5 @@ -#!/usr/bin/env python """A script to publish a release of mycli to PyPI.""" -from __future__ import print_function -import io from optparse import OptionParser import re import subprocess @@ -49,7 +46,7 @@ def version(version_file): _version_re = re.compile( r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)') - with io.open(version_file, encoding='utf-8') as f: + with open(version_file) as f: ver = _version_re.search(f.read()).group('version') return ver @@ -75,7 +72,7 @@ def upload_distribution_files(): def push_to_github(): - run_step('git', 'push', 'origin', 'master') + run_step('git', 'push', 'origin', 'main') def push_tags_to_github(): @@ -92,12 +89,6 @@ def checklist(questions): if DEBUG: subprocess.check_output = lambda x: x - checks = ['Have you created the debian package?', - 'Have you updated the AUTHORS file?', - 'Have you updated the `Usage` section of the README?', - ] - checklist(checks) - ver = version('mycli/__init__.py') print('Releasing Version:', ver) diff --git a/requirements-dev.txt b/requirements-dev.txt index 5cbc1577..9c403160 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,12 +1,13 @@ -mock pytest!=3.3.0 pytest-cov==2.4.0 tox twine==1.12.1 -behave -pexpect -coverage==4.3.4 +behave>=1.2.4 +pexpect==3.3 +coverage==5.0.4 codecov==2.0.9 autopep8==1.3.3 +colorama==0.4.1 git+https://github.com/hayd/pep8radius.git # --error-status option not released -click==6.7 +click>=7.0 +paramiko==2.7.1 diff --git a/setup.py b/setup.py index 3081f19c..f79bcd77 100755 --- a/setup.py +++ b/setup.py @@ -10,23 +10,29 @@ _version_re = re.compile(r'__version__\s+=\s+(.*)') -with open('mycli/__init__.py', 'rb') as f: - version = str(ast.literal_eval(_version_re.search( - f.read().decode('utf-8')).group(1))) +with open('mycli/__init__.py') as f: + version = ast.literal_eval(_version_re.search( + f.read()).group(1)) description = 'CLI for MySQL Database. With auto-completion and syntax highlighting.' install_requirements = [ 'click >= 7.0', - 'Pygments >= 1.6', - 'prompt_toolkit>=2.0.6', + 'cryptography >= 1.0.0', + # 'Pygments>=1.6,<=2.11.1', + 'Pygments>=1.6', + 'prompt_toolkit>=3.0.6,<4.0.0', 'PyMySQL >= 0.9.2', - 'sqlparse>=0.3.0,<0.4.0', + 'sqlparse>=0.3.0,<0.5.0', 'configobj >= 5.0.5', - 'cryptography >= 1.0.0', - 'cli_helpers[styles] > 1.1.0', + 'cli_helpers[styles] >= 2.2.1', + 'pyperclip >= 1.8.1', + 'pyaes >= 1.6.1' ] +if sys.version_info.minor < 9: + install_requirements.append('importlib_resources >= 5.0.0') + class lint(Command): description = 'check code against PEP 8 (and fix violations)' @@ -57,18 +63,26 @@ def run(self): class test(TestCommand): - user_options = [('pytest-args=', 'a', 'Arguments to pass to pytest')] + user_options = [ + ('pytest-args=', 'a', 'Arguments to pass to pytest'), + ('behave-args=', 'b', 'Arguments to pass to pytest') + ] def initialize_options(self): TestCommand.initialize_options(self) self.pytest_args = '' + self.behave_args = '--no-capture' def run_tests(self): unit_test_errno = subprocess.call( - 'pytest ' + self.pytest_args, + 'pytest test/ ' + self.pytest_args, + shell=True + ) + cli_errno = subprocess.call( + 'behave test/features ' + self.behave_args, shell=True ) - cli_errno = subprocess.call('behave test/features', shell=True) + subprocess.run(['git', 'checkout', '--', 'test/myclirc'], check=False) sys.exit(unit_test_errno or cli_errno) @@ -87,18 +101,16 @@ def run_tests(self): 'console_scripts': ['mycli = mycli.main:cli'], }, cmdclass={'lint': lint, 'test': test}, + python_requires=">=3.6", classifiers=[ 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', 'Operating System :: Unix', 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Programming Language :: SQL', 'Topic :: Database', 'Topic :: Database :: Front-Ends', diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/conftest.py b/test/conftest.py index 6daf374e..d7d10ce3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,10 @@ import pytest -from utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, - db_connection, SSH_USER, SSH_HOST, SSH_PORT) +from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, + db_connection, SSH_USER, SSH_HOST, SSH_PORT) import mycli.sqlexecute -@pytest.yield_fixture(scope="function") +@pytest.fixture(scope="function") def connection(): create_db('_test_db') connection = db_connection('_test_db') diff --git a/test/features/connection.feature b/test/features/connection.feature new file mode 100644 index 00000000..b06935ea --- /dev/null +++ b/test/features/connection.feature @@ -0,0 +1,35 @@ +Feature: connect to a database: + + @requires_local_db + Scenario: run mycli on localhost without port + When we run mycli with arguments "host=localhost" without arguments "port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli on TCP host without port + When we run mycli without arguments "port" + When we query "status" + Then status contains "via TCP/IP" + + Scenario: run mycli with port but without host + When we run mycli without arguments "host" + When we query "status" + Then status contains "via TCP/IP" + + @requires_local_db + Scenario: run mycli without host and port + When we run mycli without arguments "host port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli with my.cnf configuration + When we create my.cnf file + When we run mycli without arguments "host port user pass defaults_file" + Then we are logged in + + Scenario: run mycli with mylogin.cnf configuration + When we create mylogin.cnf file + When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file" + Then we are logged in + + diff --git a/test/features/crud_database.feature b/test/features/crud_database.feature index 0c298b69..f4a7a7f1 100644 --- a/test/features/crud_database.feature +++ b/test/features/crud_database.feature @@ -16,6 +16,10 @@ Feature: manipulate databases: when we connect to dbserver then we see database connected + Scenario: connect and disconnect from quoted test database + When we connect to quoted test database + then we see database connected + Scenario: create and drop default database When we create database then we see database created diff --git a/test/features/crud_table.feature b/test/features/crud_table.feature index d2cc9dd8..3384efd7 100644 --- a/test/features/crud_table.feature +++ b/test/features/crud_table.feature @@ -26,3 +26,24 @@ Feature: manipulate tables: then we see database connected when we select null then we see null selected + + Scenario: confirm destructive query + When we query "create table foo(x integer);" + and we query "delete from foo;" + and we answer the destructive warning with "y" + then we see text "Your call!" + + Scenario: decline destructive query + When we query "delete from foo;" + and we answer the destructive warning with "n" + then we see text "Wise choice!" + + Scenario: no destructive warning if disabled in config + When we run dbcli with --no-warn + and we query "create table blabla(x integer);" + and we query "delete from blabla;" + Then we see text "Query OK" + + Scenario: confirm destructive query with invalid response + When we query "delete from foo;" + then we answer the destructive warning with invalid "1" and see text "is not a valid boolean" diff --git a/test/features/db_utils.py b/test/features/db_utils.py index ef0b42ff..be550e9f 100644 --- a/test/features/db_utils.py +++ b/test/features/db_utils.py @@ -1,15 +1,12 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -from __future__ import print_function - import pymysql -def create_db(hostname='localhost', username=None, password=None, - dbname=None): +def create_db(hostname='localhost', port=3306, username=None, + password=None, dbname=None): """Create test database. :param hostname: string + :param port: int :param username: string :param password: string :param dbname: string @@ -18,6 +15,7 @@ def create_db(hostname='localhost', username=None, password=None, """ cn = pymysql.connect( host=hostname, + port=port, user=username, password=password, charset='utf8mb4', @@ -30,14 +28,15 @@ def create_db(hostname='localhost', username=None, password=None, cn.close() - cn = create_cn(hostname, password, username, dbname) + cn = create_cn(hostname, port, password, username, dbname) return cn -def create_cn(hostname, password, username, dbname): +def create_cn(hostname, port, password, username, dbname): """Open connection to database. :param hostname: + :param port: :param password: :param username: :param dbname: string @@ -46,6 +45,7 @@ def create_cn(hostname, password, username, dbname): """ cn = pymysql.connect( host=hostname, + port=port, user=username, password=password, db=dbname, @@ -56,11 +56,12 @@ def create_cn(hostname, password, username, dbname): return cn -def drop_db(hostname='localhost', username=None, password=None, - dbname=None): +def drop_db(hostname='localhost', port=3306, username=None, + password=None, dbname=None): """Drop database. :param hostname: string + :param port: int :param username: string :param password: string :param dbname: string @@ -68,6 +69,7 @@ def drop_db(hostname='localhost', username=None, password=None, """ cn = pymysql.connect( host=hostname, + port=port, user=username, password=password, db=dbname, diff --git a/test/features/environment.py b/test/features/environment.py index 4d090a99..1ea0f086 100644 --- a/test/features/environment.py +++ b/test/features/environment.py @@ -1,8 +1,5 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -from __future__ import print_function - import os +import shutil import sys from tempfile import mkstemp @@ -12,17 +9,39 @@ from steps.wrappers import run_cli, wait_prompt +test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') + + +SELF_CONNECTING_FEATURES = ( + 'test/features/connection.feature', +) + + +MY_CNF_PATH = os.path.expanduser('~/.my.cnf') +MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' +MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') +MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' + + +def get_db_name_from_context(context): + return context.config.userdata.get( + 'my_test_db', None + ) or "mycli_behave_tests" + + def before_all(context): """Set env parameters.""" os.environ['LINES'] = "100" os.environ['COLUMNS'] = "100" os.environ['EDITOR'] = 'ex' - os.environ['LC_ALL'] = 'en_US.utf8' + os.environ['LC_ALL'] = 'en_US.UTF-8' + os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1' + os.environ['MYCLI_HISTFILE'] = os.devnull test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) login_path_file = os.path.join(test_dir, 'mylogin.cnf') - os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file +# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file context.package_root = os.path.abspath( os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) @@ -33,8 +52,7 @@ def before_all(context): context.exit_sent = False vi = '_'.join([str(x) for x in sys.version_info[:3]]) - db_name = context.config.userdata.get( - 'my_test_db', None) or "mycli_behave_tests" + db_name = get_db_name_from_context(context) db_name_full = '{0}_{1}'.format(db_name, vi) # Store get params from config/environment variables @@ -43,6 +61,10 @@ def before_all(context): 'my_test_host', os.getenv('PYTEST_HOST', 'localhost') ), + 'port': context.config.userdata.get( + 'my_test_port', + int(os.getenv('PYTEST_PORT', '3306')) + ), 'user': context.config.userdata.get( 'my_test_user', os.getenv('PYTEST_USER', 'root') @@ -73,7 +95,8 @@ def before_all(context): context.conf['myclirc'] = os.path.join(context.package_root, 'test', 'myclirc') - context.cn = dbutils.create_db(context.conf['host'], context.conf['user'], + context.cn = dbutils.create_db(context.conf['host'], context.conf['port'], + context.conf['user'], context.conf['pass'], context.conf['dbname']) @@ -83,8 +106,9 @@ def before_all(context): def after_all(context): """Unset env parameters.""" dbutils.close_cn(context.cn) - dbutils.drop_db(context.conf['host'], context.conf['user'], - context.conf['pass'], context.conf['dbname']) + dbutils.drop_db(context.conf['host'], context.conf['port'], + context.conf['user'], context.conf['pass'], + context.conf['dbname']) # Restore env vars. #for k, v in context.pgenv.items(): @@ -98,13 +122,26 @@ def before_step(context, _): context.atprompt = False -def before_scenario(context, _): - run_cli(context) - wait_prompt(context) +def before_scenario(context, arg): + with open(test_log_file, 'w') as f: + f.write('') + if arg.location.filename not in SELF_CONNECTING_FEATURES: + run_cli(context) + wait_prompt(context) + + if os.path.exists(MY_CNF_PATH): + shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH) + + if os.path.exists(MYLOGIN_CNF_PATH): + shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH) def after_scenario(context, _): """Cleans up after each test complete.""" + with open(test_log_file) as f: + for line in f: + if 'error' in line.lower(): + raise RuntimeError(f'Error in log file: {line}') if hasattr(context, 'cli') and not context.exit_sent: # Quit nicely. @@ -113,14 +150,26 @@ def after_scenario(context, _): host = context.conf['host'] dbname = context.currentdb context.cli.expect_exact( - '{0}@{1}:{2}> '.format( + '{0}@{1}:{2}>'.format( user, host, dbname ), timeout=5 ) + context.cli.sendcontrol('c') context.cli.sendcontrol('d') context.cli.expect_exact(pexpect.EOF, timeout=5) + if os.path.exists(MY_CNF_BACKUP_PATH): + shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH) + + if os.path.exists(MYLOGIN_CNF_BACKUP_PATH): + shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH) + elif os.path.exists(MYLOGIN_CNF_PATH): + # This file was moved in `before_scenario`. + # If it exists now, it has been created during a test + os.remove(MYLOGIN_CNF_PATH) + + # TODO: uncomment to debug a failure # def after_step(context, step): # if step.status == "failed": diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt index 657db7da..2c06d5d2 100644 --- a/test/features/fixture_data/help_commands.txt +++ b/test/features/fixture_data/help_commands.txt @@ -2,6 +2,7 @@ | Command | Shortcut | Description | +-------------+----------------------------+------------------------------------------------------------+ | \G | \G | Display current query results vertically. | +| \clip | \clip | Copy query to the system clipboard. | | \dt | \dt[+] [table] | List or describe tables. | | \e | \e | Edit command with editor (uses $EDITOR). | | \f | \f [name [args..]] | List or execute favorite queries. | @@ -9,6 +10,7 @@ | \fs | \fs name query | Save a favorite query. | | \l | \l | List databases. | | \once | \o [-o] filename | Append next result to an output file (overwrite using -o). | +| \pipe_once | \| command | Send next result to a subprocess. | | \timing | \t | Toggle timing of commands. | | connect | \r | Reconnect to the database. Optional database argument. | | exit | \q | Exit. | diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py index a171e34c..f85e0f65 100644 --- a/test/features/fixture_utils.py +++ b/test/features/fixture_utils.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - import os import io @@ -13,7 +10,7 @@ def read_fixture_lines(filename): """ lines = [] - for line in io.open(filename, 'r', encoding='utf8'): + for line in open(filename): lines.append(line.strip()) return lines diff --git a/test/features/iocommands.feature b/test/features/iocommands.feature index 38efbbb0..95366eba 100644 --- a/test/features/iocommands.feature +++ b/test/features/iocommands.feature @@ -2,16 +2,46 @@ Feature: I/O commands Scenario: edit sql in file with external editor When we start external editor providing a file name - and we type sql in the editor + and we type "select * from abc" in the editor and we exit the editor then we see dbcli prompt - and we see the sql in prompt + and we see "select * from abc" in prompt Scenario: tee output from query When we tee output and we wait for prompt - and we query "select 123456" + and we select "select 123456" and we wait for prompt and we notee output and we wait for prompt then we see 123456 in tee output + + Scenario: set delimiter + When we query "delimiter $" + then delimiter is set to "$" + + Scenario: set delimiter twice + When we query "delimiter $" + and we query "delimiter ]]" + then delimiter is set to "]]" + + Scenario: set delimiter and query on same line + When we query "select 123; delimiter $ select 456 $ delimiter %" + then we see result "123" + and we see result "456" + and delimiter is set to "%" + + Scenario: send output to file + When we query "\o /tmp/output1.sql" + and we query "select 123" + and we query "system cat /tmp/output1.sql" + then we see result "123" + + Scenario: send output to file two times + When we query "\o /tmp/output1.sql" + and we query "select 123" + and we query "\o /tmp/output2.sql" + and we query "select 456" + and we query "system cat /tmp/output2.sql" + then we see result "456" + \ No newline at end of file diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py index e8cb60f0..e1cb26f8 100644 --- a/test/features/steps/auto_vertical.py +++ b/test/features/steps/auto_vertical.py @@ -1,15 +1,14 @@ -# -*- coding: utf-8 -from __future__ import unicode_literals from textwrap import dedent from behave import then, when import wrappers +from utils import parse_cli_args_to_dict @when('we run dbcli with {arg}') def step_run_cli_with_arg(context, arg): - wrappers.run_cli(context, run_args=arg.split('=')) + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) @when('we execute a small query') @@ -43,5 +42,5 @@ def step_see_large_results(context): '***************************\r\n' + '{}\r\n'.format('\r\n'.join(rows) + '\r\n')) - wrappers.expect_pager(context, expected, timeout=5) + wrappers.expect_pager(context, expected, timeout=10) wrappers.expect_exact(context, '1 row in set', timeout=2) diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py index 5764b3c6..425ef674 100644 --- a/test/features/steps/basic_commands.py +++ b/test/features/steps/basic_commands.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used to call the step in "*.feature" file. """ -from __future__ import unicode_literals from behave import when from textwrap import dedent @@ -73,9 +71,30 @@ def step_see_found(context): ''') + context.conf['pager_boundary'], timeout=5 ) + + @then(u'we confirm the destructive warning') def step_confirm_destructive_command(context): """Confirm destructive command.""" wrappers.expect_exact( context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) context.cli.sendline('y') + + +@when(u'we answer the destructive warning with "{confirmation}"') +def step_confirm_destructive_command(context, confirmation): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline(confirmation) + + +@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') +def step_confirm_destructive_command(context, confirmation, text): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline(confirmation) + wrappers.expect_exact(context, text, timeout=2) + # we must exit the Click loop, or the feature will hang + context.cli.sendline('n') diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py new file mode 100644 index 00000000..e16dd867 --- /dev/null +++ b/test/features/steps/connection.py @@ -0,0 +1,71 @@ +import io +import os +import shlex + +from behave import when, then +import pexpect + +import wrappers +from test.features.steps.utils import parse_cli_args_to_dict +from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context +from test.utils import HOST, PORT, USER, PASSWORD +from mycli.config import encrypt_mylogin_cnf + + +TEST_LOGIN_PATH = 'test_login_path' + + +@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') +@when('we run mycli without arguments "{excluded_args}"') +def step_run_cli_without_args(context, excluded_args, exact_args=''): + wrappers.run_cli( + context, + run_args=parse_cli_args_to_dict(exact_args), + exclude_args=parse_cli_args_to_dict(excluded_args).keys() + ) + + +@then('status contains "{expression}"') +def status_contains(context, expression): + wrappers.expect_exact(context, f'{expression}', timeout=5) + + # Normally, the shutdown after scenario waits for the prompt. + # But we may have changed the prompt, depending on parameters, + # so let's wait for its last character + context.cli.expect_exact('>') + context.atprompt = True + + +@when('we create my.cnf file') +def step_create_my_cnf_file(context): + my_cnf = ( + '[client]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MY_CNF_PATH, 'w') as f: + f.write(my_cnf) + + +@when('we create mylogin.cnf file') +def step_create_mylogin_cnf_file(context): + os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) + mylogin_cnf = ( + f'[{TEST_LOGIN_PATH}]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MYLOGIN_CNF_PATH, 'wb') as f: + input_file = io.StringIO(mylogin_cnf) + f.write(encrypt_mylogin_cnf(input_file).read()) + + +@then('we are logged in') +def we_are_logged_in(context): + db_name = get_db_name_from_context(context) + context.cli.expect_exact(f'{db_name}>', timeout=5) + context.atprompt = True diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py index 046e829d..841f37d0 100644 --- a/test/features/steps/crud_database.py +++ b/test/features/steps/crud_database.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 -*- """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used to call the step in "*.feature" file. """ -from __future__ import unicode_literals import pexpect @@ -36,7 +34,15 @@ def step_db_connect_test(context): """Send connect to database.""" db_name = context.conf['dbname'] context.currentdb = db_name - context.cli.sendline('use {0}'.format(db_name)) + context.cli.sendline('use {0};'.format(db_name)) + + +@when('we connect to quoted test database') +def step_db_connect_quoted_tmp(context): + """Send connect to database.""" + db_name = context.conf['dbname'] + context.currentdb = db_name + context.cli.sendline('use `{0}`;'.format(db_name)) @when('we connect to tmp database') @@ -66,15 +72,13 @@ def step_see_prompt(context): user = context.conf['user'] host = context.conf['host'] dbname = context.currentdb - wrappers.expect_exact(context, '{0}@{1}:{2}> '.format( - user, host, dbname), timeout=5) - context.atprompt = True + wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname)) @then('we see help output') def step_see_help(context): for expected_line in context.fixture_data['help_commands.txt']: - wrappers.expect_exact(context, expected_line + '\r\n', timeout=1) + wrappers.expect_exact(context, expected_line, timeout=1) @then('we see database created') @@ -98,10 +102,7 @@ def step_see_db_dropped_no_default(context): context.currentdb = None wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) - wrappers.expect_exact(context, '{0}@{1}:{2}> '.format( - user, host, database), timeout=5) - - context.atprompt = True + wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database)) @then('we see database connected') diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py index 01b2bbf9..f715f0ca 100644 --- a/test/features/steps/crud_table.py +++ b/test/features/steps/crud_table.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used to call the step in "*.feature" file. """ -from __future__ import unicode_literals import wrappers from behave import when, then diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py index 206ca802..bbabf431 100644 --- a/test/features/steps/iocommands.py +++ b/test/features/steps/iocommands.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -from __future__ import unicode_literals import os import wrappers @@ -21,10 +19,10 @@ def step_edit_file(context): wrappers.expect_exact(context, '\r\n:', timeout=2) -@when('we type sql in the editor') -def step_edit_type_sql(context): +@when('we type "{query}" in the editor') +def step_edit_type_sql(context, query): context.cli.sendline('i') - context.cli.sendline('select * from abc') + context.cli.sendline(query) context.cli.sendline('.') wrappers.expect_exact(context, '\r\n:', timeout=2) @@ -35,9 +33,9 @@ def step_edit_quit(context): wrappers.expect_exact(context, "written", timeout=2) -@then('we see the sql in prompt') -def step_edit_done_sql(context): - for match in 'select * from abc'.split(' '): +@then('we see "{query}" in prompt') +def step_edit_done_sql(context, query): + for match in query.split(' '): wrappers.expect_exact(context, match, timeout=5) # Cleanup the command line. context.cli.sendcontrol('c') @@ -56,20 +54,35 @@ def step_tee_ouptut(context): os.path.basename(context.tee_file_name))) -@when(u'we query "select 123456"') -def step_query_select_123456(context): - context.cli.sendline('select 123456') - wrappers.expect_pager(context, dedent("""\ - +--------+\r - | 123456 |\r - +--------+\r - | 123456 |\r - +--------+\r +@when(u'we select "select {param}"') +def step_query_select_number(context, param): + context.cli.sendline(u'select {}'.format(param)) + wrappers.expect_pager(context, dedent(u"""\ + +{dashes}+\r + | {param} |\r + +{dashes}+\r + | {param} |\r + +{dashes}+\r \r - """), timeout=5) + """.format(param=param, dashes='-' * (len(param) + 2)) + ), timeout=5) wrappers.expect_exact(context, '1 row in set', timeout=2) +@then(u'we see result "{result}"') +def step_see_result(context, result): + wrappers.expect_exact( + context, + u"| {} |".format(result), + timeout=2 + ) + + +@when(u'we query "{query}"') +def step_query(context, query): + context.cli.sendline(query) + + @when(u'we notee output') def step_notee_output(context): context.cli.sendline('notee') @@ -81,3 +94,12 @@ def step_see_123456_in_ouput(context): assert '123456' in f.read() if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) + + +@then(u'delimiter is set to "{delimiter}"') +def delimiter_is_set(context, delimiter): + wrappers.expect_exact( + context, + u'Changed delimiter to {}'.format(delimiter), + timeout=2 + ) diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py index b82d5f4c..bc1f8663 100644 --- a/test/features/steps/named_queries.py +++ b/test/features/steps/named_queries.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used to call the step in "*.feature" file. """ -from __future__ import unicode_literals import wrappers from behave import when, then diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py index 14a87cc3..e8b99e3e 100644 --- a/test/features/steps/specials.py +++ b/test/features/steps/specials.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used to call the step in "*.feature" file. """ -from __future__ import unicode_literals import wrappers from behave import when, then @@ -17,6 +15,11 @@ def step_refresh_completions(context): context.cli.sendline('rehash') +@then('we see text "{text}"') +def step_see_text(context, text): + """Wait to see given text message.""" + wrappers.expect_exact(context, text, timeout=2) + @then('we see completions refresh started') def step_see_refresh_started(context): """Wait to see refresh output.""" diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py new file mode 100644 index 00000000..1ae63d2b --- /dev/null +++ b/test/features/steps/utils.py @@ -0,0 +1,12 @@ +import shlex + + +def parse_cli_args_to_dict(cli_args: str): + args_dict = {} + for arg in shlex.split(cli_args): + if '=' in arg: + key, value = arg.split('=') + args_dict[key] = value + else: + args_dict[arg] = None + return args_dict diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py index 5dfd3800..6408f235 100644 --- a/test/features/steps/wrappers.py +++ b/test/features/steps/wrappers.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 -from __future__ import unicode_literals - import re import pexpect import sys import textwrap + try: from StringIO import StringIO except ImportError: @@ -16,7 +14,7 @@ def expect_exact(context, expected, timeout): timedout = False try: context.cli.expect_exact(expected, timeout=timeout) - except pexpect.exceptions.TIMEOUT: + except pexpect.TIMEOUT: timedout = True if timedout: # Strip color codes out of the output. @@ -49,21 +47,43 @@ def expect_pager(context, expected, timeout): context.conf['pager_boundary'], expected), timeout=timeout) -def run_cli(context, run_args=None): +def run_cli(context, run_args=None, exclude_args=None): """Run the process using pexpect.""" - run_args = run_args or [] - if context.conf.get('host', None): - run_args.extend(('-h', context.conf['host'])) - if context.conf.get('user', None): - run_args.extend(('-u', context.conf['user'])) - if context.conf.get('pass', None): - run_args.extend(('-p', context.conf['pass'])) - if context.conf.get('dbname', None): - run_args.extend(('-D', context.conf['dbname'])) - if context.conf.get('defaults-file', None): - run_args.extend(('--defaults-file', context.conf['defaults-file'])) - if context.conf.get('myclirc', None): - run_args.extend(('--myclirc', context.conf['myclirc'])) + run_args = run_args or {} + rendered_args = [] + exclude_args = set(exclude_args) if exclude_args else set() + + conf = dict(**context.conf) + conf.update(run_args) + + def add_arg(name, key, value): + if name not in exclude_args: + if value is not None: + rendered_args.extend((key, value)) + else: + rendered_args.append(key) + + if conf.get('host', None): + add_arg('host', '-h', conf['host']) + if conf.get('user', None): + add_arg('user', '-u', conf['user']) + if conf.get('pass', None): + add_arg('pass', '-p', conf['pass']) + if conf.get('port', None): + add_arg('port', '-P', str(conf['port'])) + if conf.get('dbname', None): + add_arg('dbname', '-D', conf['dbname']) + if conf.get('defaults-file', None): + add_arg('defaults_file', '--defaults-file', conf['defaults-file']) + if conf.get('myclirc', None): + add_arg('myclirc', '--myclirc', conf['myclirc']) + if conf.get('login_path'): + add_arg('login_path', '--login-path', conf['login_path']) + + for arg_name, arg_value in conf.items(): + if arg_name.startswith('-'): + add_arg(arg_name, arg_name, arg_value) + try: cli_cmd = context.conf['cli_command'] except KeyError: @@ -76,7 +96,7 @@ def run_cli(context, run_args=None): '"' ).format(sys.executable) - cmd_parts = [cli_cmd] + run_args + cmd_parts = [cli_cmd] + rendered_args cmd = ' '.join(cmd_parts) context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.logfile = StringIO() @@ -85,11 +105,13 @@ def run_cli(context, run_args=None): context.currentdb = context.conf['dbname'] -def wait_prompt(context): +def wait_prompt(context, prompt=None): """Make sure prompt is displayed.""" - user = context.conf['user'] - host = context.conf['host'] - dbname = context.currentdb - expect_exact(context, '{0}@{1}:{2}> '.format( - user, host, dbname), timeout=5) + if prompt is None: + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + prompt = '{0}@{1}:{2}>'.format( + user, host, dbname), + expect_exact(context, prompt, timeout=5) context.atprompt = True diff --git a/test/test_clistyle.py b/test/test_clistyle.py index e18a5303..f82cdf0e 100644 --- a/test/test_clistyle.py +++ b/test/test_clistyle.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Test the mycli.clistyle module.""" import pytest diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py index 98b89e1e..8b06ed38 100644 --- a/test/test_completion_engine.py +++ b/test/test_completion_engine.py @@ -1,5 +1,4 @@ from mycli.packages.completion_engine import suggest_type -import os import pytest @@ -394,6 +393,17 @@ def test_join_using_suggests_common_columns(col_list): 'tables': [(None, 'abc', None), (None, 'def', None)], 'drop_unique': True}] +@pytest.mark.parametrize('sql', [ + 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.', + 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.', +]) +def test_two_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'ghi', 'g')]}, + {'type': 'table', 'schema': 'g'}, + {'type': 'view', 'schema': 'g'}, + {'type': 'function', 'schema': 'g'}]) def test_2_statements_2nd_current(): suggestions = suggest_type('select * from a; select * from ', @@ -525,6 +535,13 @@ def test_source_is_file(expression): assert suggestions == [{'type': 'file_name'}] +@pytest.mark.parametrize("expression", [ + "\\f ", +]) +def test_favorite_name_suggestion(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{'type': 'favoritequery'}] + def test_order_by(): text = 'select * from foo order by ' suggestions = suggest_type(text, text) diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py index 1ed63774..cdc2fb5e 100644 --- a/test/test_completion_refresher.py +++ b/test/test_completion_refresher.py @@ -1,6 +1,6 @@ import time import pytest -from mock import Mock, patch +from unittest.mock import Mock, patch @pytest.fixture diff --git a/test/test_config.py b/test/test_config.py index 81a9ee4f..7f2b2442 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -1,5 +1,5 @@ """Unit tests for the mycli.config module.""" -from io import BytesIO, TextIOWrapper +from io import BytesIO, StringIO, TextIOWrapper import os import struct import sys @@ -7,7 +7,8 @@ import pytest from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf, - read_and_decrypt_mylogin_cnf, str_to_bool) + read_and_decrypt_mylogin_cnf, read_config_file, + str_to_bool, strip_matching_quotes) LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), 'mylogin.cnf')) @@ -20,7 +21,6 @@ def open_bmylogin_cnf(name): buf.write(f.read()) return buf - def test_read_mylogin_cnf(): """Tests that a login path file can be read and decrypted.""" mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE) @@ -138,3 +138,59 @@ def test_str_to_bool(): with pytest.raises(TypeError): str_to_bool(None) + + +def test_read_config_file_list_values_default(): + """Test that reading a config file uses list_values by default.""" + + f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + config = read_config_file(f) + + assert config['main']['weather'] == u"cloudy with a chance of meatballs" + + +def test_read_config_file_list_values_off(): + """Test that you can disable list_values when reading a config file.""" + + f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + config = read_config_file(f, list_values=False) + + assert config['main']['weather'] == u"'cloudy with a chance of meatballs'" + + +def test_strip_quotes_with_matching_quotes(): + """Test that a string with matching quotes is unquoted.""" + + s = "May the force be with you." + assert s == strip_matching_quotes('"{}"'.format(s)) + assert s == strip_matching_quotes("'{}'".format(s)) + + +def test_strip_quotes_with_unmatching_quotes(): + """Test that a string with unmatching quotes is not unquoted.""" + + s = "May the force be with you." + assert '"' + s == strip_matching_quotes('"{}'.format(s)) + assert s + "'" == strip_matching_quotes("{}'".format(s)) + + +def test_strip_quotes_with_empty_string(): + """Test that an empty string is handled during unquoting.""" + + assert '' == strip_matching_quotes('') + + +def test_strip_quotes_with_none(): + """Test that None is handled during unquoting.""" + + assert None is strip_matching_quotes(None) + + +def test_strip_quotes_with_quotes(): + """Test that strings with quotes in them are handled during unquoting.""" + + s1 = 'Darth Vader said, "Luke, I am your father."' + assert s1 == strip_matching_quotes(s1) + + s2 = '"Darth Vader said, "Luke, I am your father.""' + assert s2[1:-1] == strip_matching_quotes(s2) diff --git a/test/test_dbspecial.py b/test/test_dbspecial.py index 8b2e909c..21e389ce 100644 --- a/test/test_dbspecial.py +++ b/test/test_dbspecial.py @@ -1,5 +1,5 @@ from mycli.packages.completion_engine import suggest_type -from test_completion_engine import sorted_dicts +from .test_completion_engine import sorted_dicts from mycli.packages.special.utils import format_uptime diff --git a/test/test_main.py b/test/test_main.py index 047d1b71..7731603e 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -1,11 +1,13 @@ import os +import shutil import click from click.testing import CliRunner -from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT +from mycli.main import MyCli, cli, thanks_picker from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS -from utils import USER, HOST, PORT, PASSWORD, dbtest, run +from mycli.sqlexecute import ServerInfo +from .utils import USER, HOST, PORT, PASSWORD, dbtest, run from textwrap import dedent from collections import namedtuple @@ -13,10 +15,6 @@ from tempfile import NamedTemporaryFile from textwrap import dedent -try: - text_type = basestring -except NameError: - text_type = str test_dir = os.path.abspath(os.path.dirname(__file__)) project_dir = os.path.dirname(test_dir) @@ -144,11 +142,8 @@ def test_batch_mode_csv(executor): def test_thanks_picker_utf8(): - author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS') - sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS') - - name = thanks_picker((author_file, sponsor_file)) - assert isinstance(name, text_type) + name = thanks_picker() + assert name and isinstance(name, str) def test_help_strings_end_with_periods(): @@ -181,6 +176,7 @@ class TestExecute(): host = 'test' user = 'test' dbname = 'test' + server_info = ServerInfo.from_version_string('unknown') port = 0 def server_type(self): @@ -263,13 +259,13 @@ def test_reserved_space_is_integer(): def stub_terminal_size(): return (5, 5) - old_func = click.get_terminal_size + old_func = shutil.get_terminal_size - click.get_terminal_size = stub_terminal_size + shutil.get_terminal_size = stub_terminal_size mycli = MyCli() assert isinstance(mycli.get_reserved_space(), int) - click.get_terminal_size = old_func + shutil.get_terminal_size = old_func def test_list_dsn(): @@ -287,6 +283,24 @@ def test_list_dsn(): assert result.output == "test : mysql://test/test\n" +def test_list_ssh_config(): + runner = CliRunner() + with NamedTemporaryFile(mode="w") as ssh_config: + ssh_config.write(dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """)) + ssh_config.flush() + args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name] + result = runner.invoke(cli, args=args) + assert "test\n" in result.output + result = runner.invoke(cli, args=args + ['--verbose']) + assert "test : test.example.com\n" in result.output + + def test_dsn(monkeypatch): # Setup classes to mock mycli.main.MyCli class Formatter: @@ -388,7 +402,7 @@ def run_query(self, query, new_line=True): MockMyCli.connect_args["port"] == 5 and \ MockMyCli.connect_args["database"] == "arg_database" - # Use a DNS without password + # Use a DSN without password result = runner.invoke(mycli.main.cli, args=[ "mysql://dsn_user@dsn_host:6/dsn_database"] ) @@ -399,3 +413,116 @@ def run_query(self, query, new_line=True): MockMyCli.connect_args["host"] == "dsn_host" and \ MockMyCli.connect_args["port"] == 6 and \ MockMyCli.connect_args["database"] == "dsn_database" + + +def test_ssh_config(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = {'alias_dsn': {}} + + def __init__(self, **args): + self.logger = Logger() + self.destructive_warning = False + self.formatter = Formatter() + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + # Setup temporary configuration + with NamedTemporaryFile(mode="w") as ssh_config: + ssh_config.write(dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """)) + ssh_config.flush() + + # When a user supplies a ssh config. + result = runner.invoke(mycli.main.cli, args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test" + ]) + assert result.exit_code == 0, result.output + \ + " " + str(result.exception) + assert \ + MockMyCli.connect_args["ssh_user"] == "joe" and \ + MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ + MockMyCli.connect_args["ssh_port"] == 22222 and \ + MockMyCli.connect_args["ssh_key_filename"] == os.getenv( + "HOME") + "/.ssh/gateway" + + # When a user supplies a ssh config host as argument to mycli, + # and used command line arguments, use the command line + # arguments. + result = runner.invoke(mycli.main.cli, args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test", + "--ssh-user", "arg_user", + "--ssh-host", "arg_host", + "--ssh-port", "3", + "--ssh-key-filename", "/path/to/key" + ]) + assert result.exit_code == 0, result.output + \ + " " + str(result.exception) + assert \ + MockMyCli.connect_args["ssh_user"] == "arg_user" and \ + MockMyCli.connect_args["ssh_host"] == "arg_host" and \ + MockMyCli.connect_args["ssh_port"] == 3 and \ + MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" + + +@dbtest +def test_init_command_arg(executor): + init_command = "set sql_select_limit=1000" + sql = 'show variables like "sql_select_limit";' + runner = CliRunner() + result = runner.invoke( + cli, args=CLI_ARGS + ["--init-command", init_command], input=sql + ) + + expected = "sql_select_limit\t1000\n" + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_init_command_multiple_arg(executor): + init_command = 'set sql_select_limit=2000; set max_join_size=20000' + sql = ( + 'show variables like "sql_select_limit";\n' + 'show variables like "max_join_size"' + ) + runner = CliRunner() + result = runner.invoke( + cli, args=CLI_ARGS + ['--init-command', init_command], input=sql + ) + + expected_sql_select_limit = 'sql_select_limit\t2000\n' + expected_max_join_size = 'max_join_size\t20000\n' + + assert result.exit_code == 0 + assert expected_sql_select_limit in result.output + assert expected_max_join_size in result.output diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py index 908f9ffe..32b2abdf 100644 --- a/test/test_naive_completion.py +++ b/test/test_naive_completion.py @@ -1,4 +1,3 @@ -from __future__ import unicode_literals import pytest from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document @@ -12,7 +11,7 @@ def completer(): @pytest.fixture def complete_event(): - from mock import Mock + from unittest.mock import Mock return Mock() diff --git a/test/test_parseutils.py b/test/test_parseutils.py index f11dcdb4..920a08db 100644 --- a/test/test_parseutils.py +++ b/test/test_parseutils.py @@ -1,7 +1,7 @@ import pytest from mycli.packages.parseutils import ( - extract_tables, query_starts_with, queries_start_with, is_destructive -) + extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause, + is_dropping_database) def test_empty_string(): @@ -135,3 +135,56 @@ def test_is_destructive(): 'drop database foo;' ) assert is_destructive(sql) is True + + +def test_is_destructive_update_with_where_clause(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'UPDATE test SET x = 1 WHERE id = 1;' + ) + assert is_destructive(sql) is False + + +def test_is_destructive_update_without_where_clause(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'UPDATE test SET x = 1;' + ) + assert is_destructive(sql) is True + + +@pytest.mark.parametrize( + ('sql', 'has_where_clause'), + [ + ('update test set dummy = 1;', False), + ('update test set dummy = 1 where id = 1);', True), + ], +) +def test_query_has_where_clause(sql, has_where_clause): + assert query_has_where_clause(sql) is has_where_clause + + +@pytest.mark.parametrize( + ('sql', 'dbname', 'is_dropping'), + [ + ('select bar from foo', 'foo', False), + ('drop database "foo";', '`foo`', True), + ('drop schema foo', 'foo', True), + ('drop schema foo', 'bar', False), + ('drop database bar', 'foo', False), + ('drop database foo', None, False), + ('drop database foo; create database foo', 'foo', False), + ('drop database foo; create database bar', 'foo', True), + ('select bar from foo; drop database bazz', 'foo', False), + ('select bar from foo; drop database bazz', 'bazz', True), + ('-- dropping database \n ' + 'drop -- really dropping \n ' + 'schema abc -- now it is dropped', + 'abc', + True) + ] +) +def test_is_dropping_database(sql, dbname, is_dropping): + assert is_dropping_database(sql, dbname) == is_dropping diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py index 1838f580..2373fac8 100644 --- a/test/test_prompt_utils.py +++ b/test/test_prompt_utils.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- - - import click from mycli.packages.prompt_utils import confirm_destructive_query diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 0fda3faa..e7d460a8 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -1,7 +1,5 @@ -# coding: utf-8 -from __future__ import unicode_literals import pytest -from mock import patch +from unittest.mock import patch from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document import mycli.packages.special.main as special @@ -37,7 +35,7 @@ def completer(): @pytest.fixture def complete_event(): - from mock import Mock + from unittest.mock import Mock return Mock() diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py index c7f802b1..8b6be337 100644 --- a/test/test_special_iocommands.py +++ b/test/test_special_iocommands.py @@ -1,16 +1,15 @@ -# coding: utf-8 import os import stat import tempfile from time import time -from mock import patch +from unittest.mock import patch import pytest from pymysql import ProgrammingError import mycli.packages.special -from utils import dbtest, db_connection, send_ctrl_c +from .utils import dbtest, db_connection, send_ctrl_c def test_set_get_pager(): @@ -50,7 +49,8 @@ def test_editor_command(): assert mycli.packages.special.get_filename(r'\e filename') == "filename" os.environ['EDITOR'] = 'true' - mycli.packages.special.open_external_editor(r'select 1') == "select 1" + os.environ['VISUAL'] = 'true' + mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1" def test_tee_command(): @@ -94,9 +94,8 @@ def test_once_command(): with pytest.raises(TypeError): mycli.packages.special.execute(None, u"\\once") - mycli.packages.special.execute(None, u"\\once /proc/access-denied") with pytest.raises(OSError): - mycli.packages.special.write_once(u"hello world") + mycli.packages.special.execute(None, u"\\once /proc/access-denied") mycli.packages.special.write_once(u"hello world") # write without file set with tempfile.NamedTemporaryFile() as f: @@ -105,9 +104,24 @@ def test_once_command(): assert f.read() == b"hello world\n" mycli.packages.special.execute(None, u"\\once -o " + f.name) - mycli.packages.special.write_once(u"hello world") + mycli.packages.special.write_once(u"hello world line 1") + mycli.packages.special.write_once(u"hello world line 2") f.seek(0) - assert f.read() == b"hello world\n" + assert f.read() == b"hello world line 1\nhello world line 2\n" + + +def test_pipe_once_command(): + with pytest.raises(IOError): + mycli.packages.special.execute(None, u"\\pipe_once") + + with pytest.raises(OSError): + mycli.packages.special.execute( + None, u"\\pipe_once /proc/access-denied") + + mycli.packages.special.execute(None, u"\\pipe_once wc") + mycli.packages.special.write_once(u"hello world") + mycli.packages.special.unset_pipe_once_if_written() + # how to assert on wc output? def test_parseargfile(): @@ -232,3 +246,42 @@ def test_asserts(gen): cur=cur)) test_asserts(watch_query('-c {0!s} select 1;'.format(seconds), cur=cur)) + + +def test_split_sql_by_delimiter(): + for delimiter_str in (';', '$', '😀'): + mycli.packages.special.set_delimiter(delimiter_str) + sql_input = "select 1{} select \ufffc2".format(delimiter_str) + queries = ( + "select 1", + "select \ufffc2" + ) + for query, parsed_query in zip( + queries, mycli.packages.special.split_queries(sql_input)): + assert(query == parsed_query) + + +def test_switch_delimiter_within_query(): + mycli.packages.special.set_delimiter(';') + sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$" + queries = ( + "select 1", + "delimiter $$ select 2 $$ select 3 $$", + "select 2", + "select 3" + ) + for query, parsed_query in zip( + queries, + mycli.packages.special.split_queries(sql_input)): + assert(query == parsed_query) + + +def test_set_delimiter(): + + for delim in ('foo', 'bar'): + mycli.packages.special.set_delimiter(delim) + assert mycli.packages.special.get_current_delimiter() == delim + + +def teardown_function(): + mycli.packages.special.set_delimiter(';') diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py index d0e61662..0f38a97e 100644 --- a/test/test_sqlexecute.py +++ b/test/test_sqlexecute.py @@ -1,11 +1,10 @@ -# coding=UTF-8 - import os import pytest import pymysql -from utils import run, dbtest, set_expanded_output, is_expanded_output +from mycli.sqlexecute import ServerInfo, ServerSpecies +from .utils import run, dbtest, set_expanded_output, is_expanded_output def assert_result_equal(result, title=None, rows=None, headers=None, @@ -84,7 +83,7 @@ def test_invalid_syntax(executor): @dbtest def test_invalid_column_name(executor): - with pytest.raises(pymysql.InternalError) as excinfo: + with pytest.raises(pymysql.err.OperationalError) as excinfo: run(executor, 'select invalid command') assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) @@ -168,7 +167,7 @@ def test_favorite_query_expanded_output(executor): results = run(executor, "\\fs test-ae select * from test") assert_result_equal(results, status='Saved.') - results = run(executor, "\\f test-ae \G") + results = run(executor, "\\f test-ae \\G") assert is_expanded_output() is True assert_result_equal(results, title='> select * from test', headers=['a'], rows=[('abc',)], auto_status=False) @@ -272,3 +271,24 @@ def test_multiple_results(executor): 'status': '1 row in set'} ] assert results == expected + + +@pytest.mark.parametrize( + 'version_string, species, parsed_version_string, version', + ( + ('5.7.32-35', 'Percona', '5.7.32', 50732), + ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), + ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), + ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), + ('unexpected version string', None, '', 0), + ('', None, '', 0), + (None, None, '', 0), + ) +) +def test_version_parsing(version_string, species, parsed_version_string, version): + server_info = ServerInfo.from_version_string(version_string) + assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown + assert server_info.version_str == parsed_version_string + assert server_info.version == version diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py index 357c3257..c20c7de2 100644 --- a/test/test_tabular_output.py +++ b/test/test_tabular_output.py @@ -1,13 +1,11 @@ -# -*- coding: utf-8 -*- """Test the sql output adapter.""" -from __future__ import unicode_literals from textwrap import dedent from mycli.packages.tabular_output import sql_format from cli_helpers.tabular_output import TabularOutputFormatter -from utils import USER, PASSWORD, HOST, PORT, dbtest +from .utils import USER, PASSWORD, HOST, PORT, dbtest import pytest from mycli.main import MyCli @@ -18,20 +16,28 @@ @pytest.fixture def mycli(): cli = MyCli() - cli.connect(None, USER, PASSWORD, HOST, PORT, None) + cli.connect(None, USER, PASSWORD, HOST, PORT, None, init_command=None) return cli @dbtest def test_sql_output(mycli): """Test the sql output adapter.""" - headers = ['letters', 'number', 'optional', 'float'] + headers = ['letters', 'number', 'optional', 'float', 'binary'] class FakeCursor(object): def __init__(self): - self.data = [('abc', 1, None, 10.0), ('d', 456, '1', 0.5)] - self.description = [(None, FIELD_TYPE.VARCHAR), (None, FIELD_TYPE.LONG), - (None, FIELD_TYPE.LONG), (None, FIELD_TYPE.FLOAT)] + self.data = [ + ('abc', 1, None, 10.0, b'\xAA'), + ('d', 456, '1', 0.5, b'\xAA\xBB') + ] + self.description = [ + (None, FIELD_TYPE.VARCHAR), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.FLOAT), + (None, FIELD_TYPE.BLOB) + ] def __iter__(self): return self @@ -42,8 +48,6 @@ def __next__(self): else: raise StopIteration() - next = __next__ # Python 2 - def description(self): return self.description @@ -52,16 +56,19 @@ def description(self): [(None, None, None, 'Changed table format to sql-update')] mycli.formatter.query = "" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + actual = "\n".join(output) + assert actual == dedent('''\ UPDATE `DUAL` SET `number` = 1 , `optional` = NULL - , `float` = 10 + , `float` = 10.0e0 + , `binary` = X'aa' WHERE `letters` = 'abc'; UPDATE `DUAL` SET `number` = 456 , `optional` = '1' - , `float` = 0.5 + , `float` = 0.5e0 + , `binary` = X'aabb' WHERE `letters` = 'd';''') # Test sql-update-2 output format assert list(mycli.change_table_format("sql-update-2")) == \ @@ -71,11 +78,13 @@ def description(self): assert "\n".join(output) == dedent('''\ UPDATE `DUAL` SET `optional` = NULL - , `float` = 10 + , `float` = 10.0e0 + , `binary` = X'aa' WHERE `letters` = 'abc' AND `number` = 1; UPDATE `DUAL` SET `optional` = '1' - , `float` = 0.5 + , `float` = 0.5e0 + , `binary` = X'aabb' WHERE `letters` = 'd' AND `number` = 456;''') # Test sql-insert output format (without table name) assert list(mycli.change_table_format("sql-insert")) == \ @@ -83,9 +92,9 @@ def description(self): mycli.formatter.query = "" output = mycli.format_output(None, FakeCursor(), headers) assert "\n".join(output) == dedent('''\ - INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`) VALUES - ('abc', 1, NULL, 10) - , ('d', 456, '1', 0.5) + INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') ;''') # Test sql-insert output format (with table name) assert list(mycli.change_table_format("sql-insert")) == \ @@ -93,9 +102,9 @@ def description(self): mycli.formatter.query = "SELECT * FROM `table`" output = mycli.format_output(None, FakeCursor(), headers) assert "\n".join(output) == dedent('''\ - INSERT INTO `table` (`letters`, `number`, `optional`, `float`) VALUES - ('abc', 1, NULL, 10) - , ('d', 456, '1', 0.5) + INSERT INTO `table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') ;''') # Test sql-insert output format (with database + table name) assert list(mycli.change_table_format("sql-insert")) == \ @@ -103,7 +112,7 @@ def description(self): mycli.formatter.query = "SELECT * FROM `database`.`table`" output = mycli.format_output(None, FakeCursor(), headers) assert "\n".join(output) == dedent('''\ - INSERT INTO `database`.`table` (`letters`, `number`, `optional`, `float`) VALUES - ('abc', 1, NULL, 10) - , ('d', 456, '1', 0.5) + INSERT INTO `database`.`table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') ;''') diff --git a/test/utils.py b/test/utils.py index dc7b9de5..66b41940 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- - - import os import time import signal @@ -15,7 +12,7 @@ PASSWORD = os.getenv('PYTEST_PASSWORD') USER = os.getenv('PYTEST_USER', 'root') HOST = os.getenv('PYTEST_HOST', 'localhost') -PORT = os.getenv('PYTEST_PORT', 3306) +PORT = int(os.getenv('PYTEST_PORT', 3306)) CHARSET = os.getenv('PYTEST_CHARSET', 'utf8') SSH_USER = os.getenv('PYTEST_SSH_USER', None) SSH_HOST = os.getenv('PYTEST_SSH_HOST', None) diff --git a/tox.ini b/tox.ini index 630e59a8..612e8b7f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py27, py34, py35, py36, py37 +envlist = py36, py37, py38 [testenv] deps = pytest