diff --git a/.coveragerc b/.coveragerc index ae818eef..8d3149f6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,3 @@ [run] -parallel=True -source=mycli +parallel = True +source = mycli diff --git a/mycli/output_formatter/__init__.py b/.git-blame-ignore-revs similarity index 100% rename from mycli/output_formatter/__init__.py rename to .git-blame-ignore-revs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..b678f571 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,64 @@ +name: mycli + +on: + pull_request: + paths-ignore: + - '**.md' + +jobs: + linux: + strategy: + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + include: + - python-version: 3.6 + os: ubuntu-18.04 # MySQL 5.7.32 + - python-version: 3.7 + os: ubuntu-18.04 # MySQL 5.7.32 + - python-version: 3.8 + os: ubuntu-18.04 # MySQL 5.7.32 + - python-version: 3.9 + os: ubuntu-20.04 # MySQL 8.0.22 + + runs-on: ${{ matrix.os }} + steps: + + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Start MySQL + run: | + sudo /etc/init.d/mysql start + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-dev.txt + pip install --no-cache-dir -e . + + - name: Wait for MySQL connection + run: | + while ! 
mysqladmin ping --host=localhost --port=3306 --user=root --password=root --silent; do + sleep 5 + done + + - name: Pytest / behave + env: + PYTEST_PASSWORD: root + PYTEST_HOST: 127.0.0.1 + run: | + ./setup.py test --pytest-args="--cov-report= --cov=mycli" + + - name: Lint + run: | + ./setup.py lint --branch=HEAD + + - name: Coverage + run: | + coverage combine + coverage report + codecov diff --git a/.gitignore b/.gitignore index 59fa76be..b13429e4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,14 +1,14 @@ .idea/ +.vscode/ /build /dist /mycli.egg-info /src -/tests/behave.ini +/test/behave.ini .vagrant *.pyc *.deb -*.swp .cache/ .coverage .coverage.* diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index f78d94b0..00000000 --- a/.travis.yml +++ /dev/null @@ -1,31 +0,0 @@ -language: python -python: - - "2.7" - - "3.3" - - "3.4" - - "3.5" - - "3.6" - -install: - - pip install PyMySQL . pytest mock codecov pexpect behave - - pip install git+https://github.com/hayd/pep8radius.git - -script: - - coverage run --source mycli -m py.test - - cd tests - - behave - - cd .. - # check for pep8 errors, only looking at branch vs master. If there are errors, show diff and return an error code. - - pep8radius master --docformatter --error-status || ( pep8radius master --docformatter --diff; false ) - -after_success: - - coverage combine - - codecov - -notifications: - webhooks: - urls: - - YOUR_WEBHOOK_URL - on_success: change # options: [always|never|change] default: always - on_failure: always # options: [always|never|change] default: always - on_start: false # default: false diff --git a/AUTHORS.rst b/AUTHORS.rst new file mode 100644 index 00000000..995327f4 --- /dev/null +++ b/AUTHORS.rst @@ -0,0 +1,3 @@ +Check out our `AUTHORS`_. + +.. 
_AUTHORS: mycli/AUTHORS diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..124b19a6 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,137 @@ +# Development Guide + +This is a guide for developers who would like to contribute to this project. + +If you're interested in contributing to mycli, thank you. We'd love your help! +You'll always get credit for your work. + +## GitHub Workflow + +1. [Fork the repository](https://github.com/dbcli/mycli) on GitHub. + +2. Clone your fork locally: + ```bash + $ git clone <url-of-your-fork> + ``` + +3. Add the official repository (`upstream`) as a remote repository: + ```bash + $ git remote add upstream git@github.com:dbcli/mycli.git + ``` + +4. Set up a [virtual environment](http://docs.python-guide.org/en/latest/dev/virtualenvs) + for development: + + ```bash + $ cd mycli + $ pip install virtualenv + $ virtualenv mycli_dev + ``` + + We've just created a virtual environment that we'll use to install all the dependencies + and tools we need to work on mycli. Whenever you want to work on mycli, you + need to activate the virtual environment: + + ```bash + $ source mycli_dev/bin/activate + ``` + + When you're done working, you can deactivate the virtual environment: + + ```bash + $ deactivate + ``` + +5. Install the dependencies and development tools: + + ```bash + $ pip install -r requirements-dev.txt + $ pip install --editable . + ``` + +6. Create a branch for your bugfix or feature based off the `master` branch: + + ```bash + $ git checkout -b <name-of-bugfix-or-feature> master + ``` + +7. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date: + + ```bash + $ git pull upstream master + ``` + +8. When your work is ready for the mycli team to review it, push your branch to your fork: + + ```bash + $ git push origin <name-of-bugfix-or-feature> + ``` + +9. [Create a pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/) + on GitHub. 
+ + +## Running the Tests + +While you work on mycli, it's important to run the tests to make sure your code +hasn't broken any existing functionality. To run the tests, just type in: + +```bash +$ ./setup.py test +``` + +Mycli supports Python 2.7 and 3.4+. You can test against multiple versions of +Python by running tox: + +```bash +$ tox +``` + + +### Test Database Credentials + +The tests require a database connection to work. You can tell the tests which +credentials to use by setting the applicable environment variables: + +```bash +$ export PYTEST_HOST=localhost +$ export PYTEST_USER=user +$ export PYTEST_PASSWORD=myclirocks +$ export PYTEST_PORT=3306 +$ export PYTEST_CHARSET=utf8 +``` + +The default values are `localhost`, `root`, no password, `3306`, and `utf8`. +You only need to set the values that differ from the defaults. + + +### CLI Tests + +Some CLI tests expect the program `ex` to be a symbolic link to `vim`. + +In some systems (e.g. Arch Linux) `ex` is a symbolic link to `vi`, which will +change the output and therefore make some tests fail. + +You can check this by running: +```bash +$ readlink -f $(which ex) +``` + + +## Coding Style + +Mycli requires code submissions to adhere to +[PEP 8](https://www.python.org/dev/peps/pep-0008/). +It's easy to check the style of your code, just run: + +```bash +$ ./setup.py lint +``` + +If you see any PEP 8 style issues, you can automatically fix them by running: + +```bash +$ ./setup.py lint --fix +``` + +Be sure to commit and push any PEP 8 fixes. diff --git a/DEVELOP.rst b/DEVELOP.rst deleted file mode 100644 index 0d53e6ad..00000000 --- a/DEVELOP.rst +++ /dev/null @@ -1,86 +0,0 @@ -Development Guide ------------------ -This is a guide for developers who would like to contribute to this project. - -GitHub Workflow ---------------- - -If you're interested in contributing to mycli, first of all my heart felt -thanks. `Fork the project `_ in github. Then -clone your fork into your computer (``git clone ``). 
Make -the changes and create the commits in your local machine. Then push those -changes to your fork. Then click on the pull request icon on github and create -a new pull request. Add a description about the change and send it along. I -promise to review the pull request in a reasonable window of time and get back -to you. - -In order to keep your fork up to date with any changes from mainline, add a new -git remote to your local copy called 'upstream' and point it to the main mycli -repo. - -:: - - $ git remote add upstream git@github.com:dbcli/mycli.git - -Once the 'upstream' end point is added you can then periodically do a ``git -pull upstream master`` to update your local copy and then do a ``git push -origin master`` to keep your own fork up to date. - -Local Setup ------------ - -The installation instructions in the README file are intended for users of -mycli. If you're developing mycli, you'll need to install it in a slightly -different way so you can see the effects of your changes right away without -having to go through the install cycle everytime you change the code. - -It is highly recommended to use virtualenv for development. If you don't know -what a virtualenv is, this `guide `_ -will help you get started. - -Create a virtualenv (let's call it mycli-dev). Activate it: - -:: - - source ./mycli-dev/bin/activate - -Once the virtualenv is activated, `cd` into the local clone of mycli folder -and install mycli using pip as follows: - -:: - - $ pip install --editable . - - or - - $ pip install -e . - -This will install the necessary dependencies as well as install mycli from the -working folder into the virtualenv. By installing it using `pip install -e` -we've linked the mycli installation with the working copy. So any changes made -to the code is immediately available in the installed version of mycli. This -makes it easy to change something in the code, launch mycli and check the -effects of your change. 
- -Building DEB package from scratch --------------------- - -First pip install `make-deb`. Then run make-deb. It will create a debian folder -after asking a few questions like maintainer name, email etc. - -$ vagrant up - -PEP8 checks ------------ - -When you submit a PR, the changeset is checked for pep8 compliance using -`pep8radius `_. If you see a build failing because -of these checks, install pep8radius and apply style fixes: - -:: - - $ pip install pep8radius - $ pep8radius --docformatter --diff # view a diff of proposed fixes - $ pep8radius --docformatter --in-place # apply the fixes - -Then commit and push the fixes. diff --git a/LICENSE.txt b/LICENSE.txt index 9a41a67d..7b4904e2 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,3 @@ -Copyright (c) 2015, Amjith Ramanujam All rights reserved. Redistribution and use in source and binary forms, with or without modification, diff --git a/MANIFEST.in b/MANIFEST.in index 1d3bbc86..04f4d9a9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,6 @@ -include LICENSE.txt *.md +include LICENSE.txt *.md *.rst requirements-dev.txt screenshots/* +include tasks.py .coveragerc tox.ini +recursive-include test *.cnf +recursive-include test *.feature +recursive-include test *.py +recursive-include test *.txt diff --git a/README.md b/README.md index c12a8eb5..cc04a910 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ # mycli -[![Build Status](https://travis-ci.org/dbcli/mycli.svg?branch=master)](https://travis-ci.org/dbcli/mycli) -[![PyPI](https://img.shields.io/pypi/v/mycli.svg?style=plastic)](https://pypi.python.org/pypi/mycli) -[![Join the chat at https://gitter.im/dbcli/mycli](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dbcli/mycli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Build Status](https://github.com/dbcli/mycli/workflows/mycli/badge.svg)](https://github.com/dbcli/mycli/actions?query=workflow%3Amycli) 
+[![PyPI](https://img.shields.io/pypi/v/mycli.svg)](https://pypi.python.org/pypi/mycli) +[![LGTM](https://img.shields.io/lgtm/grade/python/github/dbcli/mycli.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/dbcli/mycli/context:python) A command line client for MySQL that can do auto-completion and syntax highlighting. HomePage: [http://mycli.net](http://mycli.net) +Documentation: [http://mycli.net/docs](http://mycli.net/docs) ![Completion](screenshots/tables.png) ![CompletionGif](screenshots/main.gif) @@ -27,7 +28,7 @@ $ pip install -U mycli or ``` -$ brew update && brew install mycli # Only on OS X +$ brew update && brew install mycli # Only on macOS ``` or @@ -41,47 +42,77 @@ $ sudo apt-get install mycli # Only on debian or ubuntu $ mycli --help Usage: mycli [OPTIONS] [DATABASE] + A MySQL terminal client with auto-completion and syntax highlighting. + + Examples: + - mycli my_database + - mycli -u my_user -h my_host.com my_database + - mycli mysql://my_user@my_host.com:3306/my_database + Options: -h, --host TEXT Host address of the database. -P, --port INTEGER Port number to use for connection. Honors - $MYSQL_TCP_PORT + $MYSQL_TCP_PORT. + -u, --user TEXT User name to connect to the database. -S, --socket TEXT The socket file to use for connection. - -p, --password TEXT Password to connect to the database - --pass TEXT Password to connect to the database - --ssl-ca PATH CA file in PEM format - --ssl-capath TEXT CA directory - --ssl-cert PATH X509 cert in PEM format - --ssl-key PATH X509 key in PEM format - --ssl-cipher TEXT SSL cipher to use + -p, --password TEXT Password to connect to the database. + --pass TEXT Password to connect to the database. + --ssh-user TEXT User name to connect to ssh server. + --ssh-host TEXT Host name to connect to ssh server. + --ssh-port INTEGER Port to connect to ssh server. + --ssh-password TEXT Password to connect to ssh server. + --ssh-key-filename TEXT Private key filename (identify file) for the + ssh connection. 
+ + --ssh-config-path TEXT Path to ssh configuration. + --ssh-config-host TEXT Host to connect to ssh server reading from ssh + configuration. + + --ssl-ca PATH CA file in PEM format. + --ssl-capath TEXT CA directory. + --ssl-cert PATH X509 cert in PEM format. + --ssl-key PATH X509 key in PEM format. + --ssl-cipher TEXT SSL cipher to use. --ssl-verify-server-cert Verify server's "Common Name" in its cert against hostname used when connecting. This - option is disabled by default - -v, --version Version of mycli. + option is disabled by default. + + -V, --version Output mycli's version. + -v, --verbose Verbose output. -D, --database TEXT Database to use. - -R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> ") + -d, --dsn TEXT Use DSN configured into the [alias_dsn] + section of myclirc file. + + --list-dsn list of DSN configured into the [alias_dsn] + section of myclirc file. + + --list-ssh-config list ssh configurations in the ssh config + (requires paramiko). + + -R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> "). -l, --logfile FILENAME Log every query and its results to a file. - --defaults-group-suffix TEXT Read config group with the specified suffix. - --defaults-file PATH Only read default options from the given file + --defaults-group-suffix TEXT Read MySQL config groups with the specified + suffix. + + --defaults-file PATH Only read MySQL options from the given file. --myclirc PATH Location of myclirc file. --auto-vertical-output Automatically switch to vertical output mode if the result is wider than the terminal width. + -t, --table Display batch output in table format. --csv Display batch output in CSV format. --warn / --no-warn Warn before running a destructive query. --local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE. - --login-path TEXT Read this path from the login file. - -e, --execute TEXT Execute query to the database. + -g, --login-path TEXT Read this path from the login file. + -e, --execute TEXT Execute command and quit. 
+ --init-command TEXT SQL statement to execute after connecting. + --charset TEXT Character set for MySQL session. + --password-file PATH File or FIFO path containing the password + to connect to the db if not specified otherwise --help Show this message and exit. -### Examples - - $ mycli local_database - - $ mycli -h localhost -u root app_db - - $ mycli mysql://amjith@localhost:3306/django_poll Features -------- @@ -95,11 +126,12 @@ Features - `SELECT * FROM ` will only show table names. - `SELECT * FROM users WHERE ` will only show column names. * Support for multiline queries. -* Favorite queries. Save a query using `\fs alias query` and execute it with `\f alias` whenever you need. -* Timing of sql statments and table rendering. +* Favorite queries with optional positional parameters. Save a query using + `\fs alias query` and execute it with `\f alias` whenever you need. +* Timing of sql statements and table rendering. * Config file is automatically created at ``~/.myclirc`` at first launch. * Log every query and its results to a file (disabled by default). -* Pretty prints tabular data. +* Pretty prints tabular data (with colors!) * Support for SSL connections Contributions: @@ -109,7 +141,7 @@ If you're interested in contributing to this project, first of all I would like to extend my heartfelt gratitude. I've written a small doc to describe how to get this running in a development setup. -https://github.com/dbcli/mycli/blob/master/DEVELOP.rst +https://github.com/dbcli/mycli/blob/master/CONTRIBUTING.md Please feel free to reach out to me if you need help. @@ -141,6 +173,16 @@ Once that is installed, you can install mycli as follows: $ sudo pip install mycli ``` +### Windows + +Follow the instructions on this blogpost: https://www.codewall.co.uk/installing-using-mycli-on-windows/ + +### Cygwin + +1. Make sure the following Cygwin packages are installed: +`python3`, `python3-pip`. +2. 
Install mycli: `pip3 install mycli` + ### Thanks: This project was funded through kickstarter. My thanks to the [backers](http://mycli.net/sponsors) who supported the project. @@ -151,35 +193,25 @@ which is quite literally the backbone library, that made this app possible. Jonathan has also provided valuable feedback and support during the development of this app. -[Click](http://click.pocoo.org/3/) is used for command line option parsing +[Click](http://click.pocoo.org/) is used for command line option parsing and printing error messages. Thanks to [PyMysql](https://github.com/PyMySQL/PyMySQL) for a pure python adapter to MySQL database. -[Tabulate](https://pypi.python.org/pypi/tabulate) library is used for pretty printing the output of tables. - ### Compatibility -Tests have been run on OS X and Linux. +Mycli is tested on macOS and Linux. -THIS HAS NOT BEEN TESTED IN WINDOWS, but the libraries used in this app are Windows compatible. This means it should work without any modifications. If you're unable to run it on Windows, please file a bug. I will try my best to fix it. +**Mycli is not tested on Windows**, but the libraries used in this app are Windows-compatible. +This means it should work without any modifications. If you're unable to run it +on Windows, please [file a bug](https://github.com/dbcli/mycli/issues/new). -### Use with pager (mysql workaround) -As described [here](https://github.com/dbcli/mycli/issues/281), " we only read the [client] section of my.cnf not the [mysql] section". +### Configuration and Usage -So, if you want to use a pager, your .my.cnf file should looks like this: +For more information on using and configuring mycli, [check out our documentation](http://mycli.net/docs). 
-``` -[mysql] -pager = mypager -[client] -pager = mypager -``` - -instead of just this : - -``` -[mysql] -pager = mypager -``` +Common topics include: +- [Configuring mycli](http://mycli.net/config) +- [Using/Disabling the pager](http://mycli.net/pager) +- [Syntax colors](http://mycli.net/syntax) diff --git a/SPONSORS.rst b/SPONSORS.rst new file mode 100644 index 00000000..173555c3 --- /dev/null +++ b/SPONSORS.rst @@ -0,0 +1,3 @@ +Check out our `SPONSORS`_. + +.. _SPONSORS: mycli/SPONSORS diff --git a/TODO b/TODO deleted file mode 100644 index 64c9842a..00000000 --- a/TODO +++ /dev/null @@ -1,10 +0,0 @@ -# vi: ft=vimwiki - -* [ ] Check if views are available in mysql. -* [ ] Create waffle.io page. -* [ ] Setup gitter. -* [ ] Setup a landing page for mycli.net. -* [ ] Send out invites to backers, pgcli contributors. -* [ ] Write a blog post on personal blog about the experience of kickstarter. -* [ ] Check mycli against MariaDB and Percona. -* [ ] Use error codes instead of matching error strings for reconnect, auto-password prompt etc. diff --git a/Vagrantfile b/Vagrantfile deleted file mode 100644 index a514d1ef..00000000 --- a/Vagrantfile +++ /dev/null @@ -1,30 +0,0 @@ -# -*- mode: ruby -*- -# vi: set ft=ruby : - -Vagrant.configure(2) do |config| - - config.vm.synced_folder ".", "/mycli" - - config.vm.define "debian" do |debian| - debian.vm.box = "debian/jessie64" - debian.vm.provision "shell", inline: <<-SHELL - echo "-> Building DEB" - sudo apt-get update - sudo echo "deb http://ppa.launchpad.net/spotify-jyrki/dh-virtualenv/ubuntu trusty main" >> /etc/apt/sources.list - sudo echo "deb-src http://ppa.launchpad.net/spotify-jyrki/dh-virtualenv/ubuntu trusty main" >> /etc/apt/sources.list - sudo apt-get update - sudo apt-get install -y --force-yes python-virtualenv dh-virtualenv debhelper build-essential python-setuptools python-dev - echo "-> Cleaning up old workspace" - rm -rf build - mkdir -p build - cp -r /mycli build/. 
- cd build/mycli - - echo "-> Creating mycli deb" - dpkg-buildpackage -us -uc - cp ../*.deb /mycli/. - SHELL - end - -end - diff --git a/changelog.md b/changelog.md index af5c81fa..b5522d2e 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,465 @@ +1.24.4 (2022/03/30) +=================== + +Internal: +--------- +* Upgrade Ubuntu VM for runners as Github has deprecated it + +Bug Fixes: +---------- +* Change in main.py - Replace the `click.get_terminal_size()` with `shutil.get_terminal_size()` + + +1.24.3 (2022/01/20) +=================== + +Bug Fixes: +---------- +* Upgrade cli_helpers to workaround Pygments regression. + + +1.24.2 (2022/01/11) +=================== + +Bug Fixes: +---------- +* Fix autocompletion for more than one JOIN +* Fix the status command when connected to TiDB or other servers that don't implement 'Threads\_connected' +* Pin pygments version to avoid a breaking change + +1.24.1: +======= + +Bug Fixes: +--------- +* Restore dependency on cryptography for the interactive password prompt + +Internal: +--------- +* Deprecate Python mock + + +1.24.0 +====== + +Bug Fixes: +---------- +* Allow `FileNotFound` exception for SSH config files. +* Fix startup error on MySQL < 5.0.22 +* Check error code rather than message for Access Denied error +* Fix login with ~/.my.cnf files + +Features: +--------- +* Add `-g` shortcut to option `--login-path`. +* Alt-Enter dispatches the command in multi-line mode. +* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html) + +Internal: +--------- +* Remove unused function is_open_quote() +* Use importlib, instead of file links, to locate resources +* Test various host-port combinations in command line arguments +* Switched from Cryptography to pyaes for decrypting mylogin.cnf + + +1.23.2 +====== + +Bug Fixes: +---------- +* Ensure `--port` is always an int. 
+ +1.23.1 +====== + +Bug Fixes: +---------- +* Allow `--host` without `--port` to make a TCP connection. + +1.23.0 +====== + +Bug Fixes: +---------- +* Fix config file include logic + +Features: +--------- + +* Add an option `--init-command` to execute SQL after connecting (Thanks: [KITAGAWA Yasutaka]). +* Use InputMode.REPLACE_SINGLE +* Add support for ANSI escape sequences for coloring the prompt. +* Allow customization of Pygments SQL syntax-highlighting styles. +* Add a `\clip` special command to copy queries to the system clipboard. +* Add a special command `\pipe_once` to pipe output to a subprocess. +* Add an option `--charset` to set the default charset when connect database. + +Bug Fixes: +---------- +* Fixed compatibility with sqlparse 0.4 (Thanks: [mtorromeo]). +* Fixed iPython magic (Thanks: [mwcm]). +* Send "Connecting to socket" message to the standard error. +* Respect empty string for prompt_continuation via `prompt_continuation = ''` in `.myclirc` +* Fix \once -o to overwrite output whole, instead of line-by-line. +* Dispatch lines ending with `\e` or `\clip` on return, even in multiline mode. +* Restore working local `--socket=` (Thanks: [xeron]). +* Allow backtick quoting around the database argument to the `use` command. +* Avoid opening `/dev/tty` when `--no-warn` is given. +* Fixed some typo errors in `README.md`. + +1.22.2 +====== + +Bug Fixes: +---------- + +* Make the `pwd` module optional. + +1.22.1 +====== + +Bug Fixes: +---------- +* Fix the breaking change introduced in PyMySQL 0.10.0. (Thanks: [Amjith]). + +Features: +--------- +* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file. +* Add an option `--list-ssh-config` to list ssh configurations. +* Add an option `--ssh-config-path` to choose ssh configuration path. + +Bug Fixes: +---------- + +* Fix specifying empty password with `--password=''` when config file has a password set (Thanks: [Zach DeCook]). 
+ + +1.21.1 +====== + + +Bug Fixes: +---------- + +* Fix broken auto-completion for favorite queries (Thanks: [Amjith]). +* Fix undefined variable exception when running with --no-warn (Thanks: [Georgy Frolov]) +* Support setting color for null value (Thanks: [laixintao]) + +1.21.0 +====== + +Features: +--------- +* Added DSN alias name as a format specifier to the prompt (Thanks: [Georgy Frolov]). +* Mark `update` without `where`-clause as destructive query (Thanks: [Klaus Wünschel]). +* Added DELIMITER command (Thanks: [Georgy Frolov]) +* Added clearer error message when failing to connect to the default socket. +* Extend main.is_dropping_database check with create after delete statement. +* Search `${XDG_CONFIG_HOME}/mycli/myclirc` after `${HOME}/.myclirc` and before `/etc/myclirc` (Thanks: [Takeshi D. Itoh]) + +Bug Fixes: +---------- + +* Allow \o command more than once per session (Thanks: [Georgy Frolov]) +* Fixed crash when the query dropping the current database starts with a comment (Thanks: [Georgy Frolov]) + +Internal: +--------- +* deprecate python versions 2.7, 3.4, 3.5; support python 3.8 + +1.20.1 +====== + +Bug Fixes: +---------- + +* Fix an error when using login paths with an explicit database name (Thanks: [Thomas Roten]). + +1.20.0 +====== + +Features: +---------- +* Auto find alias dsn when `://` not in `database` (Thanks: [QiaoHou Peng]). +* Mention URL encoding as escaping technique for special characters in connection DSN (Thanks: [Aljosha Papsch]). +* Pressing Alt-Enter will introduce a line break. This is a way to break up the query into multiple lines without switching to multi-line mode. (Thanks: [Amjith Ramanujam]). +* Use a generator to stream the output to the pager (Thanks: [Dick Marinus]). + +Bug Fixes: +---------- + +* Fix the missing completion for special commands (Thanks: [Amjith Ramanujam]). 
+* Fix favorites queries being loaded/stored only from/in default config file and not --myclirc (Thanks: [Matheus Rosa]) +* Fix automatic vertical output with native syntax style (Thanks: [Thomas Roten]). +* Update `cli_helpers` version, this will remove quotes from batch output like the official client (Thanks: [Dick Marinus]) +* Update `setup.py` to no longer require `sqlparse` to be less than 0.3.0 as that just came out and there are no notable changes. ([VVelox]) +* workaround for ConfigObj parsing strings containing "," as lists (Thanks: [Mike Palandra]) + +Internal: +--------- +* fix unhashable FormattedText from prompt toolkit in unit tests (Thanks: [Dick Marinus]). + +1.19.0 +====== + +Internal: +--------- + +* Add Python 3.7 trove classifier (Thanks: [Thomas Roten]). +* Fix pytest in Fedora mock (Thanks: [Dick Marinus]). +* Require `prompt_toolkit>=2.0.6` (Thanks: [Dick Marinus]). + +Features: +--------- + +* Add Token.Prompt/Continuation (Thanks: [Dick Marinus]). +* Don't reconnect when switching databases using use (Thanks: [Angelo Lupo]). +* Handle MemoryErrors while trying to pipe in large files and exit gracefully with an error (Thanks: [Amjith Ramanujam]) + +Bug Fixes: +---------- + +* Enable Ctrl-Z to suspend the app (Thanks: [Amjith Ramanujam]). + +1.18.2 +====== + +Bug Fixes: +---------- + +* Fixes database reconnecting feature (Thanks: [Yang Zou]). + +Internal: +--------- + +* Update Twine version to 1.12.1 (Thanks: [Thomas Roten]). +* Fix warnings for running tests on Python 3.7 (Thanks: [Dick Marinus]). +* Clean up and add behave logging (Thanks: [Dick Marinus]). + +1.18.1 +====== + +Features: +--------- + +* Add Keywords: TINYINT, SMALLINT, MEDIUMINT, INT, BIGINT (Thanks: [QiaoHou Peng]). + +Internal: +--------- + +* Update prompt toolkit (Thanks: [Jonathan Slenders], [Irina Truong], [Dick Marinus]). + +1.18.0 +====== + +Features: +--------- + +* Display server version in welcome message (Thanks: [Irina Truong]). 
+* Set `program_name` connection attribute (Thanks: [Dick Marinus]). +* Use `return` to terminate a generator for better Python 3.7 support (Thanks: [Zhongyang Guan]). +* Add `SAVEPOINT` to SQLCompleter (Thanks: [Huachao Mao]). +* Connect using a SSH transport (Thanks: [Dick Marinus]). +* Add `FROM_UNIXTIME` and `UNIX_TIMESTAMP` to SQLCompleter (Thanks: [QiaoHou Peng]) +* Search `${PWD}/.myclirc`, then `${HOME}/.myclirc`, lastly `/etc/myclirc` (Thanks: [QiaoHao Peng]) + +Bug Fixes: +---------- + +* When DSN is used, allow overrides from mycli arguments (Thanks: [Dick Marinus]). +* A DSN without password should be allowed (Thanks: [Dick Marinus]) + +Bug Fixes: +---------- + +* Convert `sql_format` to unicode strings for py27 compatibility (Thanks: [Dick Marinus]). +* Fixes mycli compatibility with pbr (Thanks: [Thomas Roten]). +* Don't align decimals for `sql_format` (Thanks: [Dick Marinus]). + +Internal: +--------- + +* Use fileinput (Thanks: [Dick Marinus]). +* Enable tests for Python 3.7 (Thanks: [Thomas Roten]). +* Remove `*.swp` from gitignore (Thanks: [Dick Marinus]). + +1.17.0: +======= + +Features: +---------- + +* Add `CONCAT` to SQLCompleter and remove unused code (Thanks: [caitinggui]) +* Do not quit when aborting a confirmation prompt (Thanks: [Thomas Roten]). +* Add option list-dsn (Thanks: [Frederic Aoustin]). +* Add verbose option for list-dsn, add tests and clean up code (Thanks: [Dick Marinus]). + +Bug Fixes: +---------- + +* Add enable_pager to the config file (Thanks: [Frederic Aoustin]). +* Mark `test_sql_output` as a dbtest (Thanks: [Dick Marinus]). +* Don't crash if the log/history file directories don't exist (Thanks: [Thomas Roten]). +* Unquote dsn username and password (Thanks: [Dick Marinus]). +* Output `Password:` prompt to stderr (Thanks: [ushuz]). +* Mark `alter` as a destructive query (Thanks: [Dick Marinus]). +* Quote CSV fields (Thanks: [Thomas Roten]). +* Fix `thanks_picker` (Thanks: [Dick Marinus]). 
+ +Internal: +--------- + +* Refactor Destructive Warning behave tests (Thanks: [Dick Marinus]). + + +1.16.0: +======= + +Features: +--------- + +* Add DSN aliases to the config file (Thanks: [Frederic Aoustin]). + +Bug Fixes: +---------- + +* Do not try to connect to a unix socket on Windows (Thanks: [Thomas Roten]). + +1.15.0: +======= + +Features: +--------- + +* Add sql-update/insert output format. (Thanks: [Dick Marinus]). +* Also complete aliases in WHERE. (Thanks: [Dick Marinus]). + +1.14.0: +======= + +Features: +--------- + +* Add `watch [seconds] query` command to repeat a query every [seconds] seconds (by default 5). (Thanks: [David Caro](https://github.com/Terseus)) +* Default to unix socket connection if host and port are unspecified. This simplifies authentication on some systems and matches mysql behaviour. +* Add support for positional parameters to favorite queries. (Thanks: [Scrappy Soft](https://github.com/scrappysoft)) + +Bug Fixes: +---------- + +* Fix source command for script in current working directory. (Thanks: [Dick Marinus]). +* Fix issue where the `tee` command did not work on Python 2.7 (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Drop support for Python 3.3 (Thanks: [Thomas Roten]). + +* Make tests more compatible between different build environments. (Thanks: [David Caro]) +* Merge `_on_completions_refreshed` and `_swap_completer_objects` functions (Thanks: [Dick Marinus]). + +1.13.1: +======= + +Bug Fixes: +---------- + +* Fix keyword completion suggestion for `SHOW` (Thanks: [Thomas Roten]). +* Prevent mycli from crashing when failing to read login path file (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Make tests ignore user config files (Thanks: [Thomas Roten]). + +1.13.0: +======= + +Features: +--------- + +* Add file name completion for source command (issue #500). (Thanks: [Irina Truong]). 
+ +Bug Fixes: +---------- + +* Fix UnicodeEncodeError when editing sql command in external editor (Thanks: Klaus Wünschel). +* Fix MySQL4 version comment retrieval (Thanks: [François Pietka]) +* Fix error that occurred when outputting JSON and NULL data (Thanks: [Thomas Roten]). + +1.12.1: +======= + +Bug Fixes: +---------- + +* Prevent missing MySQL help database from causing errors in completions (Thanks: [Thomas Roten]). +* Fix mycli from crashing with small terminal windows under Python 2 (Thanks: [Thomas Roten]). +* Prevent an error from displaying when you drop the current database (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Use less memory when formatting results for display (Thanks: [Dick Marinus]). +* Preliminary work for a future change in outputting results that uses less memory (Thanks: [Dick Marinus]). + +1.12.0: +======= + +Features: +--------- + +* Add fish-style auto-suggestion from history. (Thanks: [Amjith Ramanujam]) + + +1.11.0: +======= + +Features: +--------- + +* Handle reserved space for completion menu better in small windows. (Thanks: [Thomas Roten]). +* Display current vi mode in toolbar. (Thanks: [Thomas Roten]). +* Opening an external editor will edit the last-run query. (Thanks: [Thomas Roten]). +* Output once special command. (Thanks: [Dick Marinus]). +* Add special command to show create table statement. (Thanks: [Ryan Smith]) +* Display all result sets returned by stored procedures (Thanks: [Thomas Roten]). +* Add current time to prompt options (Thanks: [Thomas Roten]). +* Output status text in a more intuitive way (Thanks: [Thomas Roten]). +* Add colored/styled headers and odd/even rows (Thanks: [Thomas Roten]). +* Keyword completion casing (upper/lower/auto) (Thanks: [Irina Truong]). + +Bug Fixes: +---------- + +* Fixed incorrect timekeeping when running queries from a file. (Thanks: [Thomas Roten]). +* Do not display time and empty line for blank queries (Thanks: [Thomas Roten]). 
+* Fixed issue where quit command would sometimes not work (Thanks: [Thomas Roten]). +* Remove shebang from main.py (Thanks: [Dick Marinus]). +* Only use pager if output doesn't fit. (Thanks: [Dick Marinus]). +* Support tilde user directory for output file names (Thanks: [Thomas Roten]). +* Auto vertical output is a little bit better at its calculations (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Rename tests/ to test/. (Thanks: [Dick Marinus]). +* Move AUTHORS and SPONSORS to mycli directory. (Thanks: [Terje Røsten] []). +* Switch from pycryptodome to cryptography (Thanks: [Thomas Roten]). +* Add pager wrapper for behave tests (Thanks: [Dick Marinus]). +* Behave test source command (Thanks: [Dick Marinus]). +* Test using behave the tee command (Thanks: [Dick Marinus]). +* Behave fix clean up. (Thanks: [Dick Marinus]). +* Remove output formatter code in favor of CLI Helpers dependency (Thanks: [Thomas Roten]). +* Better handle common before/after scenarios in behave. (Thanks: [Dick Marinus]) +* Added a regression test for sqlparse >= 0.2.3 (Thanks: [Dick Marinus]). +* Reverted removal of temporary hack for sqlparse (Thanks: [Dick Marinus]). +* Add setup.py commands to simplify development tasks (Thanks: [Thomas Roten]). +* Add behave tests to tox (Thanks: [Dick Marinus]). +* Add missing @dbtest to tests (Thanks: [Dick Marinus]). +* Standardizes punctuation/grammar for help strings (Thanks: [Thomas Roten]). + 1.10.0: ======= @@ -400,22 +862,31 @@ Bug Fixes: ---------- * Fixed the installation issues with PyMySQL dependency on case-sensitive file systems. 
+[Amjith Ramanujam]: https://blog.amjith.com +[Artem Bezsmertnyi]: https://github.com/mrdeathless +[Carlos Afonso]: https://github.com/afonsocarlos +[Casper Langemeijer]: https://github.com/langemeijer [Daniel West]: http://github.com/danieljwest +[Dick Marinus]: https://github.com/meeuw +[François Pietka]: https://github.com/fpietka +[Frederic Aoustin]: https://github.com/fraoustin +[Georgy Frolov]: https://github.com/pasenor [Irina Truong]: https://github.com/j-bennet +[Jonathan Slenders]: https://github.com/jonathanslenders [Kacper Kwapisz]: https://github.com/KKKas +[laixintao]: https://github.com/laixintao +[Lennart Weller]: https://github.com/lhw [Martijn Engler]: https://github.com/martijnengler [Matheus Rosa]: https://github.com/mdsrosa -[Shoma Suzuki]: https://github.com/shoma -[spacewander]: https://github.com/spacewander -[Thomas Roten]: https://github.com/tsroten -[Artem Bezsmertnyi]: https://github.com/mrdeathless [Mikhail Borisov]: https://github.com/borman -[Casper Langemeijer]: Casper Langemeijer -[Lennart Weller]: https://github.com/lhw +[mtorromeo]: https://github.com/mtorromeo +[mwcm]: https://github.com/mwcm [Phil Cohen]: https://github.com/phlipper +[Scrappy Soft]: https://github.com/scrappysoft +[Shoma Suzuki]: https://github.com/shoma +[spacewander]: https://github.com/spacewander [Terseus]: https://github.com/Terseus +[Thomas Roten]: https://github.com/tsroten [William GARCIA]: https://github.com/willgarcia -[Jonathan Slenders]: https://github.com/jonathanslenders -[Casper Langemeijer]: https://github.com/langemeijer -[Scrappy Soft]: https://github.com/scrappysoft -[Dick Marinus]: https://github.com/meeuw +[xeron]: https://github.com/xeron +[Zach DeCook]: https://zachdecook.com diff --git a/conftest.py b/conftest.py deleted file mode 100644 index d2cd1336..00000000 --- a/conftest.py +++ /dev/null @@ -1,6 +0,0 @@ -import sys -collect_ignore = [ - "setup.py", - "mycli/magic.py", - "mycli/packages/parseutils.py", -] diff --git a/create_deb.sh 
b/create_deb.sh deleted file mode 100755 index 8e75ebb3..00000000 --- a/create_deb.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/sh - -set -e - -make-deb -cd debian - -cat > postinst <<- EOM -#!/bin/bash - -echo "Setting up symlink to mycli" -ln -sf /usr/share/python/mycli/bin/mycli /usr/local/bin/mycli -EOM -echo "Created postinst file." - -cat > postrm <<- EOM -#!/bin/bash - -echo "Removing symlink to mycli" -rm /usr/local/bin/mycli -EOM -echo "Created postrm file." - -for f in * -do - echo "" >> $f; -done - -echo "INFO: debian folder is setup and ready." -echo "INFO: 1. Update the changelog with real changes." -echo "INFO: 2. Run:\n\tvagrant provision || vagrant up" diff --git a/debian/changelog b/debian/changelog deleted file mode 100644 index 5135bf68..00000000 --- a/debian/changelog +++ /dev/null @@ -1,74 +0,0 @@ -mycli (1.7.0) unstable; urgency=medium - - * Add stdin batch mode. (Thanks: Thomas Roten). - * Add warn/no-warn command-line options. (Thanks: Thomas Roten). - * Upgrade sqlparse dependency to 0.1.19. (Thanks: [Amjith Ramanujam]). - * Update features list in README.md. (Thanks: Matheus Rosa). - * Remove extra \n in features list in README.md. (Thanks: Matheus Rosa). - * Enable history search via . (Thanks: [Amjith Ramanujam]). - * Upgrade prompt_toolkit to 1.0.0. (Thanks: Jonathan Slenders) - - -- Casper Langemeijer Fri, 27 May 2016 12:03:31 +0200 - -mycli (1.6.0) unstable; urgency=medium - - * Change continuation prompt for multi-line mode to match default mysql. - * Add status command to match mysql's status command. (Thanks: Thomas Roten). - * Add SSL support for mycli. (Thanks: Artem Bezsmertnyi). - * Add auto-completion and highlight support for OFFSET keyword. (Thanks: Matheus Rosa). - * Add support for MYSQL_TEST_LOGIN_FILE env variable to specify alternate login file. (Thanks: Thomas Roten). - * Add support for --auto-vertical-output to automatically switch to vertical output if the output doesn't fit in the table format. 
- * Add support for system-wide config. Now /etc/myclirc will be honored. (Thanks: Thomas Roten). - * Add support for nopager and \n to turn off the pager. (Thanks: Thomas Roten). - * Add support for --local-infile command-line option. (Thanks: Thomas Roten). - * Remove -S from less option which was clobbering the scroll back in history. (Thanks: Thomas Roten). - * Make system command work with Python 3. (Thanks: Thomas Roten). - * Support \G terminator for \f queries. (Thanks: Terseus). - * Upgrade prompt_toolkit to 0.60. - * Add Python 3.5 to test environments. (Thanks: Thomas Roten). - * Remove license meta-data. (Thanks: Thomas Roten). - * Skip binary tests if PyMySQL version does not support it. (Thanks: Thomas Roten). - * Refactor pager handling. (Thanks: Thomas Roten) - * Capture warnings to log file. (Thanks: Mikhail Borisov). - * Make syntax_style a tiny bit more intuitive. (Thanks: Phil Cohen). - - -- Casper Langemeijer Fri, 27 May 2016 12:03:31 +0200 - -mycli (1.5.2) unstable; urgency=low - - * Protect against port number being None when no port is specified in command line. - * Cast the value of port read from my.cnf to int. - * Make a config option to enable `audit_log`. (Thanks: [Matheus Rosa]). - * Add support for reading .mylogin.cnf to get user credentials. (Thanks: [Thomas Roten]). - * Register the special command `prompt` with the `\R` as alias. (Thanks: [Matheus Rosa]). - * Perform completion refresh in a background thread. Now mycli can handle - * Add support for `system` command. (Thanks: [Matheus Rosa]). - * Caught and hexed binary fields in MySQL. (Thanks: [Daniel West]). - * Treat enter key as tab when the suggestion menu is open. (Thanks: [Matheus Rosa]) - * Add "delete" and "truncate" as destructive commands. (Thanks: [Martijn Engler]). - * Change \dt syntax to add an optional table name. (Thanks: [Shoma Suzuki]). - * Add TRANSACTION related keywords. - * Treat DESC and EXPLAIN as DESCRIBE. (Thanks: [spacewander]). 
- * Fix the removal of whitespace from table output. - * Add ability to make suggestions for compound join clauses. (Thanks: [Matheus Rosa]). - * Fix the incorrect reporting of command time. - * Add type validation for port argument. (Thanks [Matheus Rosa]) - * Make pycrypto optional and only install it in \*nix systems. (Thanks: [Iryna Cherniavska]). - * Add badge for PyPI version to README. (Thanks: [Shoma Suzuki]). - * Updated release script with a --dry-run and --confirm-steps option. (Thanks: [Iryna Cherniavska]). - * Adds support for PyMySQL 0.6.2 and above. This is useful for debian package builders. (Thanks: [Thomas Roten]). - * Disable click warning. - - -- Casper Langemeijer Sun, 15 Nov 2015 10:26:24 +0100 - -mycli (1.4.0) unstable; urgency=low - - * Add `source` command. This allows running sql statement from a file. - * Added a config option to make the warning before destructive commands optional. (Thanks: [Daniel West](https://github.com/danieljwest)) - * Add completion support for CHANGE TO and other master/slave commands. This is still preliminary and it will be enhanced in the future. - * Add custom styles to color the menus and toolbars. - * Upgrade prompt_toolkit to 0.46. (Thanks: [Jonathan Slenders](https://github.com/jonathanslenders)) - * Fix keyword completion after the `WHERE` clause. - * Add `\g` and `\G` as valid query terminators. Previously in multi-line mode ending a query with a `\G` wouldn't run the query. This is now fixed. 
- - -- Amjith Ramanujam Sun, 23 Aug 2015 20:14:45 +0000 diff --git a/debian/compat b/debian/compat deleted file mode 100644 index ec635144..00000000 --- a/debian/compat +++ /dev/null @@ -1 +0,0 @@ -9 diff --git a/debian/control b/debian/control deleted file mode 100644 index 41838326..00000000 --- a/debian/control +++ /dev/null @@ -1,13 +0,0 @@ -Source: mycli -Section: python -Priority: extra -Maintainer: Amjith Ramanujam -Build-Depends: debhelper (>= 9), python, dh-virtualenv (>= 0.7), python-setuptools, python-dev -Standards-Version: 3.9.5 - -Package: mycli -Architecture: any -Pre-Depends: dpkg (>= 1.16.1), python2.7-minimal, ${misc:Pre-Depends} -Depends: ${python:Depends}, ${misc:Depends} -Description: CLI for MySQL Database. With auto-completion and syntax highlighting. - CLI for MySQL Database. With auto-completion and syntax highlighting. diff --git a/debian/mycli.triggers b/debian/mycli.triggers deleted file mode 100644 index b0b1d218..00000000 --- a/debian/mycli.triggers +++ /dev/null @@ -1,8 +0,0 @@ -# Register interest in Python interpreter changes (Python 2 for now); and -# don't make the Python package dependent on the virtualenv package -# processing (noawait) -interest-noawait /usr/bin/python2.7 - -# Also provide a symbolic trigger for all dh-virtualenv packages -interest dh-virtualenv-interpreter-update - diff --git a/debian/postinst b/debian/postinst deleted file mode 100644 index 122ff285..00000000 --- a/debian/postinst +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -echo "Setting up symlink to mycli" -ln -sf /usr/share/python/mycli/bin/mycli /usr/local/bin/mycli - diff --git a/debian/postrm b/debian/postrm deleted file mode 100644 index 850d6e8d..00000000 --- a/debian/postrm +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -echo "Removing symlink to mycli" -rm /usr/local/bin/mycli - diff --git a/debian/rules b/debian/rules deleted file mode 100644 index 299e3091..00000000 --- a/debian/rules +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/make -f - -%: - dh $@ 
--with python-virtualenv diff --git a/AUTHORS b/mycli/AUTHORS similarity index 51% rename from AUTHORS rename to mycli/AUTHORS index 2eb4af10..d1f3a280 100644 --- a/AUTHORS +++ b/mycli/AUTHORS @@ -6,7 +6,7 @@ Project Lead: Core Developers: ---------------- - * Iryna Cherniavska + * Irina Truong * Matheus Rosa * Darik Gamble * Dick Marinus @@ -15,42 +15,81 @@ Core Developers: Contributors: ------------- - * Steve Robbins - * Shoma Suzuki - * Daniel West - * Scrappy Soft - * Daniel Black - * Jonathan Bruno - * Casper Langemeijer - * Jonathan Slenders + * 0xflotus + * Abirami P + * Adam Chainz + * Aljosha Papsch + * Andy Teijelo Pérez + * Angelo Lupo * Artem Bezsmertnyi - * Mikhail Borisov + * bitkeen + * bjarnagin + * caitinggui + * Carlos Afonso + * Casper Langemeijer + * chainkite + * Colin Caine + * cxbig + * Daniel Black + * Daniel West + * Daniël van Eeden + * François Pietka + * Frederic Aoustin + * Georgy Frolov * Heath Naylor - * Phil Cohen - * spacewander - * Adam Chainz + * Huachao Mao + * Jakub Boukal + * jbruno + * Jerome Provensal + * Jialong Liu * Johannes Hoff + * John Sterling + * Jonathan Bruno + * Jonathan Lloyd + * Jonathan Slenders * Kacper Kwapisz + * Karthikeyan Singaravelan + * kevinhwang91 + * KITAGAWA Yasutaka + * Klaus Wünschel + * laixintao * Lennart Weller * Martijn Engler + * Massimiliano Torromeo + * Michał Górny + * Mike Palandra + * Mikhail Borisov + * Morgan Mitchell + * mrdeathless + * Nathan Huang + * Nicolas Palumbo + * Phil Cohen + * QiaoHou Peng + * Roland Walker + * Ryan Smith + * Scrappy Soft + * Seamile + * Shoma Suzuki + * spacewander + * Steve Robbins + * Takeshi D. Itoh + * Terje Røsten * Terseus * Tyler Kuipers + * ushuz * William GARCIA + * xeron + * Yang Zou * Yasuhiro Matsumoto - * bjarnagin - * jbruno - * mrdeathless - * Abirami P - * John Sterling - * Jialong Liu - * Zhidong - * Daniël van Eeden + * Zach DeCook + * Zane C. 
Bowers-Hadley * zer09 - * cxbig - * chainkite - * Michał Górny + * Zhaolong Zhu + * Zhidong + * Zhongyang Guan + * Arvind Mishra -Creator: --------- +Created by: +----------- Amjith Ramanujam diff --git a/SPONSORS b/mycli/SPONSORS similarity index 100% rename from SPONSORS rename to mycli/SPONSORS diff --git a/mycli/__init__.py b/mycli/__init__.py index 52af183e..e10d6ee2 100644 --- a/mycli/__init__.py +++ b/mycli/__init__.py @@ -1 +1 @@ -__version__ = '1.10.0' +__version__ = '1.24.4' diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py index 41a63df1..81353b63 100644 --- a/mycli/clibuffer.py +++ b/mycli/clibuffer.py @@ -1,17 +1,20 @@ -from prompt_toolkit.buffer import Buffer +from prompt_toolkit.enums import DEFAULT_BUFFER from prompt_toolkit.filters import Condition +from prompt_toolkit.application import get_app +from .packages import special -class CLIBuffer(Buffer): - def __init__(self, always_multiline, *args, **kwargs): - self.always_multiline = always_multiline - @Condition - def is_multiline(): - doc = self.document - return self.always_multiline and not _multiline_exception(doc.text) +def cli_is_multiline(mycli): + @Condition + def cond(): + doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document + + if not mycli.multi_line: + return False + else: + return not _multiline_exception(doc.text) + return cond - super(self.__class__, self).__init__(*args, is_multiline=is_multiline, - tempfile_suffix='.sql', **kwargs) def _multiline_exception(text): orig = text @@ -23,12 +26,30 @@ def _multiline_exception(text): if text.startswith('\\fs'): return orig.endswith('\n') - return (text.startswith('\\') or # Special Command - text.endswith(';') or # Ended with a semi-colon - text.endswith('\\g') or # Ended with \g - text.endswith('\\G') or # Ended with \G - (text == 'exit') or # Exit doesn't need semi-colon - (text == 'quit') or # Quit doesn't need semi-colon - (text == ':q') or # To all the vim fans out there - (text == '') # Just a plain enter without any 
text - ) + return ( + # Special Command + text.startswith('\\') or + + # Delimiter declaration + text.lower().startswith('delimiter') or + + # Ended with the current delimiter (usually a semi-colon) + text.endswith(special.get_current_delimiter()) or + + text.endswith('\\g') or + text.endswith('\\G') or + text.endswith(r'\e') or + text.endswith(r'\clip') or + + # Exit doesn't need a semi-colon + (text == 'exit') or + + # Quit doesn't need a semi-colon + (text == 'quit') or + + # To all the vim fans out there + (text == ':q') or + + # just a plain enter without any text + (text == '') + ) diff --git a/mycli/clistyle.py b/mycli/clistyle.py index ef7c1c9d..b0ac9922 100644 --- a/mycli/clistyle.py +++ b/mycli/clistyle.py @@ -1,7 +1,94 @@ -from pygments.token import string_to_tokentype -from pygments.util import ClassNotFound -from prompt_toolkit.styles import default_style_extensions, style_from_dict +import logging + import pygments.styles +from pygments.token import string_to_tokentype, Token +from pygments.style import Style as PygmentsStyle +from pygments.util import ClassNotFound +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles import merge_styles, Style + +logger = logging.getLogger(__name__) + +# map Pygments tokens (ptk 1.0) to class names (ptk 2.0). 
+TOKEN_TO_PROMPT_STYLE = { + Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current', + Token.Menu.Completions.Completion: 'completion-menu.completion', + Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current', + Token.Menu.Completions.Meta: 'completion-menu.meta.completion', + Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta', + Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess + Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess + Token.SelectedText: 'selected', + Token.SearchMatch: 'search', + Token.SearchMatch.Current: 'search.current', + Token.Toolbar: 'bottom-toolbar', + Token.Toolbar.Off: 'bottom-toolbar.off', + Token.Toolbar.On: 'bottom-toolbar.on', + Token.Toolbar.Search: 'search-toolbar', + Token.Toolbar.Search.Text: 'search-toolbar.text', + Token.Toolbar.System: 'system-toolbar', + Token.Toolbar.Arg: 'arg-toolbar', + Token.Toolbar.Arg.Text: 'arg-toolbar.text', + Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid', + Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed', + Token.Output.Header: 'output.header', + Token.Output.OddRow: 'output.odd-row', + Token.Output.EvenRow: 'output.even-row', + Token.Output.Null: 'output.null', + Token.Prompt: 'prompt', + Token.Continuation: 'continuation', +} + +# reverse dict for cli_helpers, because they still expect Pygments tokens. 
+PROMPT_STYLE_TO_TOKEN = { + v: k for k, v in TOKEN_TO_PROMPT_STYLE.items() +} + +# all tokens that the Pygments MySQL lexer can produce +OVERRIDE_STYLE_TO_TOKEN = { + 'sql.comment': Token.Comment, + 'sql.comment.multi-line': Token.Comment.Multiline, + 'sql.comment.single-line': Token.Comment.Single, + 'sql.comment.optimizer-hint': Token.Comment.Special, + 'sql.escape': Token.Error, + 'sql.keyword': Token.Keyword, + 'sql.datatype': Token.Keyword.Type, + 'sql.literal': Token.Literal, + 'sql.literal.date': Token.Literal.Date, + 'sql.symbol': Token.Name, + 'sql.quoted-schema-object': Token.Name.Quoted, + 'sql.quoted-schema-object.escape': Token.Name.Quoted.Escape, + 'sql.constant': Token.Name.Constant, + 'sql.function': Token.Name.Function, + 'sql.variable': Token.Name.Variable, + 'sql.number': Token.Number, + 'sql.number.binary': Token.Number.Bin, + 'sql.number.float': Token.Number.Float, + 'sql.number.hex': Token.Number.Hex, + 'sql.number.integer': Token.Number.Integer, + 'sql.operator': Token.Operator, + 'sql.punctuation': Token.Punctuation, + 'sql.string': Token.String, + 'sql.string.double-quouted': Token.String.Double, + 'sql.string.escape': Token.String.Escape, + 'sql.string.single-quoted': Token.String.Single, + 'sql.whitespace': Token.Text, +} + +def parse_pygments_style(token_name, style_object, style_dict): + """Parse token type and style string. + + :param token_name: str name of Pygments token. 
Example: "Token.String" + :param style_object: pygments.style.Style instance to use as base + :param style_dict: dict of token names and their styles, customized to this cli + + """ + token_type = string_to_tokentype(token_name) + try: + other_token_type = string_to_tokentype(style_dict[token_name]) + return token_type, style_object.styles[other_token_type] + except AttributeError as err: + return token_type, style_dict[token_name] def style_factory(name, cli_style): @@ -10,10 +97,56 @@ def style_factory(name, cli_style): except ClassNotFound: style = pygments.styles.get_style_by_name('native') - styles = {} - styles.update(style.styles) - styles.update(default_style_extensions) - custom_styles = {string_to_tokentype(x): y for x, y in cli_style.items()} - styles.update(custom_styles) + prompt_styles = [] + # prompt-toolkit used pygments tokens for styling before, switched to style + # names in 2.0. Convert old token types to new style names, for backwards compatibility. + for token in cli_style: + if token.startswith('Token.'): + # treat as pygments token (1.0) + token_type, style_value = parse_pygments_style( + token, style, cli_style) + if token_type in TOKEN_TO_PROMPT_STYLE: + prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] + prompt_styles.append((prompt_style, style_value)) + else: + # we don't want to support tokens anymore + logger.error('Unhandled style / class name: %s', token) + else: + # treat as prompt style name (2.0). 
See default style names here: + # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py + prompt_styles.append((token, cli_style[token])) + + override_style = Style([('bottom-toolbar', 'noreverse')]) + return merge_styles([ + style_from_pygments_cls(style), + override_style, + Style(prompt_styles) + ]) + + +def style_factory_output(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name).styles + except ClassNotFound: + style = pygments.styles.get_style_by_name('native').styles + + for token in cli_style: + if token.startswith('Token.'): + token_type, style_value = parse_pygments_style( + token, style, cli_style) + style.update({token_type: style_value}) + elif token in PROMPT_STYLE_TO_TOKEN: + token_type = PROMPT_STYLE_TO_TOKEN[token] + style.update({token_type: cli_style[token]}) + elif token in OVERRIDE_STYLE_TO_TOKEN: + token_type = OVERRIDE_STYLE_TO_TOKEN[token] + style.update({token_type: cli_style[token]}) + else: + # TODO: cli helpers will have to switch to ptk.Style + logger.error('Unhandled style / class name: %s', token) + + class OutputStyle(PygmentsStyle): + default_style = "" + styles = style - return style_from_dict(styles) + return OutputStyle diff --git a/mycli/clitoolbar.py b/mycli/clitoolbar.py index b62d8edb..eec2978f 100644 --- a/mycli/clitoolbar.py +++ b/mycli/clitoolbar.py @@ -1,37 +1,53 @@ -from pygments.token import Token -from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.key_binding.vi_state import InputMode +from prompt_toolkit.application import get_app +from prompt_toolkit.enums import EditingMode +from .packages import special -def create_toolbar_tokens_func(get_is_refreshing): - """ - Return a function that generates the toolbar tokens. 
- """ - token = Token.Toolbar - def get_toolbar_tokens(cli): +def create_toolbar_tokens_func(mycli, show_fish_help): + """Return a function that generates the toolbar tokens.""" + def get_toolbar_tokens(): result = [] - result.append((token, ' ')) - - if cli.buffers[DEFAULT_BUFFER].completer.smart_completion: - result.append((token.On, '[F2] Smart Completion: ON ')) - else: - result.append((token.Off, '[F2] Smart Completion: OFF ')) - - if cli.buffers[DEFAULT_BUFFER].always_multiline: - result.append((token.On, '[F3] Multiline: ON ')) + result.append(('class:bottom-toolbar', ' ')) + + if mycli.multi_line: + delimiter = special.get_current_delimiter() + result.append( + ( + 'class:bottom-toolbar', + ' ({} [{}] will end the line) '.format( + 'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter) + )) + + if mycli.multi_line: + result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON ')) else: - result.append((token.Off, '[F3] Multiline: OFF ')) - - if cli.buffers[DEFAULT_BUFFER].always_multiline: - result.append((token, - ' (Semi-colon [;] will end the line)')) - - if cli.editing_mode == EditingMode.VI: - result.append((token.On, '[F4] Vi-mode')) - else: - result.append((token.On, '[F4] Emacs-mode')) - - if get_is_refreshing(): - result.append((token, ' Refreshing completions...')) + result.append(('class:bottom-toolbar.off', + '[F3] Multiline: OFF ')) + if mycli.prompt_app.editing_mode == EditingMode.VI: + result.append(( + 'class:botton-toolbar.on', + 'Vi-mode ({})'.format(_get_vi_mode()) + )) + + if show_fish_help(): + result.append( + ('class:bottom-toolbar', ' Right-arrow to complete suggestion')) + + if mycli.completion_refresher.is_refreshing(): + result.append( + ('class:bottom-toolbar', ' Refreshing completions...')) return result return get_toolbar_tokens + + +def _get_vi_mode(): + """Get the current vi mode for display.""" + return { + InputMode.INSERT: 'I', + InputMode.NAVIGATION: 'N', + InputMode.REPLACE: 'R', + InputMode.REPLACE_SINGLE: 'R', 
+ InputMode.INSERT_MULTIPLE: 'M', + }[get_app().vi_state.input_mode] diff --git a/mycli/compat.py b/mycli/compat.py new file mode 100644 index 00000000..2ebfe07f --- /dev/null +++ b/mycli/compat.py @@ -0,0 +1,6 @@ +"""Platform and Python version compatibility support.""" + +import sys + + +WIN = sys.platform in ('win32', 'cygwin') diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py index 2bbe32d0..124068a9 100644 --- a/mycli/completion_refresher.py +++ b/mycli/completion_refresher.py @@ -13,7 +13,7 @@ def __init__(self): self._completer_thread = None self._restart_refresh = threading.Event() - def refresh(self, executor, callbacks, completer_options={}): + def refresh(self, executor, callbacks, completer_options=None): """Creates a SQLCompleter object and populates it with the relevant completion suggestions in a background thread. @@ -25,6 +25,9 @@ def refresh(self, executor, callbacks, completer_options={}): completer_options - dict of options to pass to SQLCompleter. """ + if completer_options is None: + completer_options = {} + if self.is_refreshing(): self._restart_refresh.set() return [(None, None, None, 'Auto-completion refresh restarted.')] @@ -33,7 +36,7 @@ def refresh(self, executor, callbacks, completer_options={}): target=self._bg_refresh, args=(executor, callbacks, completer_options), name='completion_refresh') - self._completer_thread.setDaemon(True) + self._completer_thread.daemon = True self._completer_thread.start() return [(None, None, None, 'Auto-completion refresh started in the background.')] @@ -47,7 +50,9 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options): # Create a new pgexecute method to popoulate the completions. 
e = sqlexecute executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port, - e.socket, e.charset, e.local_infile, e.ssl) + e.socket, e.charset, e.local_infile, e.ssl, + e.ssh_user, e.ssh_host, e.ssh_port, + e.ssh_password, e.ssh_key_filename) # If callbacks is a single function then push it into a list. if callable(callbacks): diff --git a/mycli/config.py b/mycli/config.py index 7f5e0cb2..5d711093 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -1,21 +1,30 @@ -from __future__ import print_function -import shutil +from copy import copy from io import BytesIO, TextIOWrapper import logging import os from os.path import exists import struct import sys +from typing import Union, IO + from configobj import ConfigObj, ConfigObjError +import pyaes + +try: + import importlib.resources as resources +except ImportError: + # Python < 3.7 + import importlib_resources as resources + try: basestring except NameError: basestring = str -from Crypto.Cipher import AES logger = logging.getLogger(__name__) + def log(logger, level, message): """Logs message to stderr if logging isn't initialized.""" @@ -24,18 +33,28 @@ def log(logger, level, message): else: print(message, file=sys.stderr) -def read_config_file(f): - """Read a config file.""" + +def read_config_file(f, list_values=True): + """Read a config file. + + *list_values* set to `True` is the default behavior of ConfigObj. + Disabling it causes values to not be parsed for lists, + (e.g. 'a,b,c' -> ['a', 'b', 'c']. Additionally, the config values are + not unquoted. We are disabling list_values when reading MySQL config files + so we can correctly interpret commas in passwords. 
+ + """ if isinstance(f, basestring): f = os.path.expanduser(f) try: - config = ConfigObj(f, interpolation=False, encoding='utf8') + config = ConfigObj(f, interpolation=False, encoding='utf8', + list_values=list_values) except ConfigObjError as e: - log(logger, logging.ERROR, "Unable to parse line {0} of config file " + log(logger, logging.WARNING, "Unable to parse line {0} of config file " "'{1}'.".format(e.line_number, f)) - log(logger, logging.ERROR, "Using successfully parsed config values.") + log(logger, logging.WARNING, "Using successfully parsed config values.") return e.config except (IOError, OSError) as e: log(logger, logging.WARNING, "You don't have permission to read " @@ -44,25 +63,74 @@ def read_config_file(f): return config -def read_config_files(files): + +def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list: + """Get a list of configuration files that are included into config_path + with !includedir directive. + + "Normal" configs should be passed as file paths. The only exception + is .mylogin which is decoded into a stream. However, it never + contains include directives and so will be ignored by this + function. 
+ + """ + if not isinstance(config_file, str) or not os.path.isfile(config_file): + return [] + included_configs = [] + + try: + with open(config_file) as f: + include_directives = filter( + lambda s: s.startswith('!includedir'), + f + ) + dirs = map(lambda s: s.strip().split()[-1], include_directives) + dirs = filter(os.path.isdir, dirs) + for dir in dirs: + for filename in os.listdir(dir): + if filename.endswith('.cnf'): + included_configs.append(os.path.join(dir, filename)) + except (PermissionError, UnicodeDecodeError): + pass + return included_configs + + +def read_config_files(files, list_values=True): """Read and merge a list of config files.""" - config = ConfigObj() + config = create_default_config(list_values=list_values) + _files = copy(files) + while _files: + _file = _files.pop(0) + _config = read_config_file(_file, list_values=list_values) - for _file in files: - _config = read_config_file(_file) + # expand includes only if we were able to parse config + # (otherwise we'll just encounter the same errors again) + if config is not None: + _files = get_included_configs(_file) + _files if bool(_config) is True: config.merge(_config) config.filename = _config.filename return config -def write_default_config(source, destination, overwrite=False): + +def create_default_config(list_values=True): + import mycli + default_config_file = resources.open_text(mycli, 'myclirc') + return read_config_file(default_config_file, list_values=list_values) + + +def write_default_config(destination, overwrite=False): + import mycli + default_config = resources.read_text(mycli, 'myclirc') destination = os.path.expanduser(destination) if not overwrite and exists(destination): return - shutil.copyfile(source, destination) + with open(destination, 'w') as f: + f.write(default_config) + def get_mylogin_cnf_path(): """Return the path to the login path file or None if it doesn't exist.""" @@ -80,6 +148,7 @@ def get_mylogin_cnf_path(): return mylogin_cnf_path return None + def 
open_mylogin_cnf(name): """Open a readable version of .mylogin.cnf. @@ -92,7 +161,7 @@ def open_mylogin_cnf(name): try: with open(name, 'rb') as f: plaintext = read_and_decrypt_mylogin_cnf(f) - except (OSError, IOError): + except (OSError, IOError, ValueError): logger.error('Unable to open login path file.') return None @@ -102,6 +171,59 @@ def open_mylogin_cnf(name): return TextIOWrapper(plaintext) + +# TODO reuse code between encryption an decryption +def encrypt_mylogin_cnf(plaintext: IO[str]): + """Encryption of .mylogin.cnf file, analogous to calling + mysql_config_editor. + + Code is based on the python implementation by Kristian Koehntopp + https://github.com/isotopp/mysql-config-coder + + """ + def realkey(key): + """Create the AES key from the login key.""" + rkey = bytearray(16) + for i in range(len(key)): + rkey[i % 16] ^= key[i] + return bytes(rkey) + + def encode_line(plaintext, real_key, buf_len): + aes = pyaes.AESModeOfOperationECB(real_key) + text_len = len(plaintext) + pad_len = buf_len - text_len + pad_chr = bytes(chr(pad_len), "utf8") + plaintext = plaintext.encode() + pad_chr * pad_len + encrypted_text = b''.join( + [aes.encrypt(plaintext[i: i + 16]) + for i in range(0, len(plaintext), 16)] + ) + return encrypted_text + + LOGIN_KEY_LENGTH = 20 + key = os.urandom(LOGIN_KEY_LENGTH) + real_key = realkey(key) + + outfile = BytesIO() + + outfile.write(struct.pack("i", 0)) + outfile.write(key) + + while True: + line = plaintext.readline() + if not line: + break + real_len = len(line) + pad_len = (int(real_len / 16) + 1) * 16 + + outfile.write(struct.pack("i", pad_len)) + x = encode_line(line, real_key, pad_len) + outfile.write(x) + + outfile.seek(0) + return outfile + + def read_and_decrypt_mylogin_cnf(f): """Read and decrypt the contents of .mylogin.cnf. @@ -143,11 +265,9 @@ def read_and_decrypt_mylogin_cnf(f): return None rkey = struct.pack('16B', *rkey) - # Create a cipher object using the key. 
- aes_cipher = AES.new(rkey, AES.MODE_ECB) - # Create a bytes buffer to hold the plaintext. plaintext = BytesIO() + aes = pyaes.AESModeOfOperationECB(rkey) while True: # Read the length of the ciphertext. @@ -158,24 +278,12 @@ def read_and_decrypt_mylogin_cnf(f): # Read cipher_len bytes from the file and decrypt. cipher = f.read(cipher_len) - pplain = aes_cipher.decrypt(cipher) - - try: - # Determine pad length. - pad_len = ord(pplain[-1:]) - except TypeError: - # ord() was unable to get the value of the byte. - logger.warning('Unable to remove pad.') + plain = _remove_pad( + b''.join([aes.decrypt(cipher[i: i + 16]) + for i in range(0, cipher_len, 16)]) + ) + if plain is False: continue - - if pad_len > len(pplain) or len(set(pplain[-pad_len:])) != 1: - # Pad length should be less than or equal to the length of the - # plaintext. The pad should have a single unqiue byte. - logger.warning('Invalid pad found in login path file.') - continue - - # Get rid of pad. - plain = pplain[:-pad_len] plaintext.write(plain) if plaintext.tell() == 0: @@ -185,6 +293,7 @@ def read_and_decrypt_mylogin_cnf(f): plaintext.seek(0) return plaintext + def str_to_bool(s): """Convert a string value to its corresponding boolean value.""" if isinstance(s, bool): @@ -200,4 +309,36 @@ def str_to_bool(s): elif s.lower() in false_values: return False else: - raise ValueError('not a recognized boolean value: %s'.format(s)) + raise ValueError('not a recognized boolean value: {0}'.format(s)) + + +def strip_matching_quotes(s): + """Remove matching, surrounding quotes from a string. + + This is the same logic that ConfigObj uses when parsing config + values. + + """ + if (isinstance(s, basestring) and len(s) >= 2 and + s[0] == s[-1] and s[0] in ('"', "'")): + s = s[1:-1] + return s + + +def _remove_pad(line): + """Remove the pad from the *line*.""" + try: + # Determine pad length. + pad_length = ord(line[-1:]) + except TypeError: + # ord() was unable to get the value of the byte. 
+ logger.warning('Unable to remove pad.') + return False + + if pad_length > len(line) or len(set(line[-pad_length:])) != 1: + # Pad length should be less than or equal to the length of the + # plaintext. The pad should have a single unique byte. + logger.warning('Invalid pad found in login path file.') + return False + + return line[:-pad_length] diff --git a/mycli/encodingutils.py b/mycli/encodingutils.py deleted file mode 100644 index 1a8b5bbb..00000000 --- a/mycli/encodingutils.py +++ /dev/null @@ -1,58 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -import binascii -import sys - -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 - -if PY2: - text_type = unicode - binary_type = str -else: - text_type = str - binary_type = bytes - - -def unicode2utf8(arg): - """Convert strings to UTF8-encoded bytes. - - Only in Python 2. In Python 3 the args are expected as unicode. - - """ - - if PY2 and isinstance(arg, text_type): - return arg.encode('utf-8') - return arg - - -def utf8tounicode(arg): - """Convert UTF8-encoded bytes to strings. - - Only in Python 2. In Python 3 the errors are returned as strings. - - """ - - if PY2 and isinstance(arg, binary_type): - return arg.decode('utf-8') - return arg - - -def bytes_to_string(b): - """Convert bytes to a string. Hexlify bytes that can't be decoded. 
- - >>> print(bytes_to_string(b"\\xff")) - 0xff - >>> print(bytes_to_string('abc')) - abc - >>> print(bytes_to_string('✌')) - ✌ - - """ - if isinstance(b, binary_type): - try: - return b.decode('utf8') - except UnicodeDecodeError: - return '0x' + binascii.hexlify(b).decode('ascii') - return b diff --git a/mycli/filters.py b/mycli/filters.py deleted file mode 100644 index 6a8075ff..00000000 --- a/mycli/filters.py +++ /dev/null @@ -1,12 +0,0 @@ -from prompt_toolkit.filters import Filter - -class HasSelectedCompletion(Filter): - """Enable when the current buffer has a selected completion.""" - - def __call__(self, cli): - complete_state = cli.current_buffer.complete_state - return (complete_state is not None and - complete_state.current_completion is not None) - - def __repr__(self): - return "HasSelectedCompletion()" diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py index 1651347e..4a24c82b 100644 --- a/mycli/key_bindings.py +++ b/mycli/key_bindings.py @@ -1,65 +1,49 @@ import logging from prompt_toolkit.enums import EditingMode -from prompt_toolkit.keys import Keys -from prompt_toolkit.key_binding.manager import KeyBindingManager -from prompt_toolkit.filters import Condition -from .filters import HasSelectedCompletion +from prompt_toolkit.filters import completion_is_selected +from prompt_toolkit.key_binding import KeyBindings _logger = logging.getLogger(__name__) -def mycli_bindings(): - """ - Custom key bindings for mycli. - """ - key_binding_manager = KeyBindingManager( - enable_open_in_editor=True, - enable_system_bindings=True, - enable_search=True, - enable_abort_and_exit_bindings=True) +def mycli_bindings(mycli): + """Custom key bindings for mycli.""" + kb = KeyBindings() - @key_binding_manager.registry.add_binding(Keys.F2) + @kb.add('f2') def _(event): - """ - Enable/Disable SmartCompletion Mode. 
- """ + """Enable/Disable SmartCompletion Mode.""" _logger.debug('Detected F2 key.') - buf = event.cli.current_buffer - buf.completer.smart_completion = not buf.completer.smart_completion + mycli.completer.smart_completion = not mycli.completer.smart_completion - @key_binding_manager.registry.add_binding(Keys.F3) + @kb.add('f3') def _(event): - """ - Enable/Disable Multiline Mode. - """ + """Enable/Disable Multiline Mode.""" _logger.debug('Detected F3 key.') - buf = event.cli.current_buffer - buf.always_multiline = not buf.always_multiline + mycli.multi_line = not mycli.multi_line - @key_binding_manager.registry.add_binding(Keys.F4) + @kb.add('f4') def _(event): - """ - Toggle between Vi and Emacs mode. - """ + """Toggle between Vi and Emacs mode.""" _logger.debug('Detected F4 key.') - if event.cli.editing_mode == EditingMode.VI: - event.cli.editing_mode = EditingMode.EMACS + if mycli.key_bindings == "vi": + event.app.editing_mode = EditingMode.EMACS + mycli.key_bindings = "emacs" else: - event.cli.editing_mode = EditingMode.VI + event.app.editing_mode = EditingMode.VI + mycli.key_bindings = "vi" - @key_binding_manager.registry.add_binding(Keys.Tab) + @kb.add('tab') def _(event): - """ - Force autocompletion at cursor. - """ + """Force autocompletion at cursor.""" _logger.debug('Detected key.') - b = event.cli.current_buffer + b = event.app.current_buffer if b.complete_state: b.complete_next() else: - event.cli.start_completion(select_first=True) + b.start_completion(select_first=True) - @key_binding_manager.registry.add_binding(Keys.ControlSpace) + @kb.add('c-space') def _(event): """ Initialize autocompletion at cursor. 
@@ -71,21 +55,35 @@ def _(event): """ _logger.debug('Detected key.') - b = event.cli.current_buffer + b = event.app.current_buffer if b.complete_state: b.complete_next() else: - event.cli.start_completion(select_first=False) + b.start_completion(select_first=False) - @key_binding_manager.registry.add_binding(Keys.ControlJ, filter=HasSelectedCompletion()) + @kb.add('enter', filter=completion_is_selected) def _(event): + """Makes the enter key work as the tab key only when showing the menu. + + In other words, don't execute query when enter is pressed in + the completion dropdown menu, instead close the dropdown menu + (accept current selection). + """ - Makes the enter key work as the tab key only when showing the menu. - """ - _logger.debug('Detected key.') + _logger.debug('Detected enter key.') event.current_buffer.complete_state = None - b = event.cli.current_buffer + b = event.app.current_buffer b.complete_state = None - return key_binding_manager + @kb.add('escape', 'enter') + def _(event): + """Introduces a line break in multi-line mode, or dispatches the + command in single-line mode.""" + _logger.debug('Detected alt-enter key.') + if mycli.multi_line: + event.app.current_buffer.validate_and_handle() + else: + event.app.current_buffer.insert_text('\n') + + return kb diff --git a/mycli/magic.py b/mycli/magic.py index 5527f72d..aad229a5 100644 --- a/mycli/magic.py +++ b/mycli/magic.py @@ -19,7 +19,7 @@ def load_ipython_extension(ipython): def mycli_line_magic(line): _logger.debug('mycli magic called: %r', line) parsed = sql.parse.parse(line, {}) - conn = sql.connection.Connection.get(parsed['connection']) + conn = sql.connection.Connection(parsed['connection']) try: # A corresponding mycli object already exists @@ -30,7 +30,7 @@ def mycli_line_magic(line): u = conn.session.engine.url _logger.debug('New mycli: %r', str(u)) - mycli.connect(u.database, u.host, u.username, u.port, u.password) + mycli.connect(host=u.host, port=u.port, passwd=u.password, 
database=u.database, user=u.username, init_command=None) conn._mycli = mycli # For convenience, print the connection alias diff --git a/mycli/main.py b/mycli/main.py index ff3e8d9c..c13ed780 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,67 +1,93 @@ -#!/usr/bin/env python -from __future__ import unicode_literals -from __future__ import print_function - +from collections import defaultdict +from io import open import os -import os.path import sys +import shutil import traceback import logging import threading +import re +import stat +import fileinput +from collections import namedtuple +try: + from pwd import getpwuid +except ImportError: + pass from time import time from datetime import datetime from random import choice -from io import open +from pymysql import OperationalError +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output import preprocessors +from cli_helpers.utils import strip_ansi import click import sqlparse -from prompt_toolkit import CommandLineInterface, Application, AbortAction -from prompt_toolkit.interface import AcceptAction +from mycli.packages.parseutils import is_dropping_database, is_destructive +from prompt_toolkit.completion import DynamicCompleter from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode -from prompt_toolkit.shortcuts import create_prompt_layout, create_eventloop +from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register +from prompt_toolkit.shortcuts import PromptSession, CompleteStyle from prompt_toolkit.document import Document -from prompt_toolkit.filters import Always, HasFocus, IsDone +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ANSI from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor, ConditionalProcessor) +from prompt_toolkit.lexers import PygmentsLexer from prompt_toolkit.history import FileHistory -from pygments.token import Token +from 
prompt_toolkit.auto_suggest import AutoSuggestFromHistory from .packages.special.main import NO_QUERY -import mycli.packages.special as special +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.tabular_output import sql_format +from .packages import special +from .packages.special.favoritequeries import FavoriteQueries from .sqlcompleter import SQLCompleter from .clitoolbar import create_toolbar_tokens_func -from .clistyle import style_factory -from .sqlexecute import SQLExecute -from .clibuffer import CLIBuffer +from .clistyle import style_factory, style_factory_output +from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED +from .clibuffer import cli_is_multiline from .completion_refresher import CompletionRefresher from .config import (write_default_config, get_mylogin_cnf_path, - open_mylogin_cnf, read_config_files, str_to_bool) + open_mylogin_cnf, read_config_files, str_to_bool, + strip_matching_quotes) from .key_bindings import mycli_bindings -from .output_formatter import output_formatter -from .encodingutils import utf8tounicode from .lexer import MyCliLexer -from .__init__ import __version__ +from . 
import __version__ +from .compat import WIN +from .packages.filepaths import dir_path_exists, guess_socket_location + +import itertools click.disable_unicode_literals_warning = True try: from urlparse import urlparse - FileNotFoundError = OSError + from urlparse import unquote except ImportError: from urllib.parse import urlparse -from pymysql import OperationalError + from urllib.parse import unquote -from collections import namedtuple +try: + import importlib.resources as resources +except ImportError: + # Python < 3.7 + import importlib_resources as resources + +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko # Query tuples are used for maintaining history Query = namedtuple('Query', ['query', 'successful', 'mutating']) -PACKAGE_ROOT = os.path.dirname(__file__) +SUPPORT_INFO = ( + 'Home: http://mycli.net\n' + 'Bug tracker: https://github.com/dbcli/mycli/issues' +) -# no-op logging handler -class NullHandler(logging.Handler): - def emit(self, record): - pass class MyCli(object): @@ -74,15 +100,20 @@ class MyCli(object): '/etc/my.cnf', '/etc/mysql/my.cnf', '/usr/local/etc/my.cnf', - '~/.my.cnf' + os.path.expanduser('~/.my.cnf'), ] + # check XDG_CONFIG_HOME exists and not an empty string + if os.environ.get("XDG_CONFIG_HOME"): + xdg_config_home = os.environ.get("XDG_CONFIG_HOME") + else: + xdg_config_home = "~/.config" system_config_files = [ '/etc/myclirc', + os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc") ] - default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc') - + pwd_config_file = os.path.join(os.getcwd(), ".myclirc") def __init__(self, sqlexecute=None, prompt=None, logfile=None, defaults_suffix=None, defaults_file=None, @@ -101,17 +132,27 @@ def __init__(self, sqlexecute=None, prompt=None, self.cnf_files = [defaults_file] # Load config. 
- config_files = ([self.default_config_file] + self.system_config_files + - [myclirc]) + config_files = (self.system_config_files + + [myclirc] + [self.pwd_config_file]) c = self.config = read_config_files(config_files) self.multi_line = c['main'].as_bool('multi_line') self.key_bindings = c['main']['key_bindings'] special.set_timing_enabled(c['main'].as_bool('timing')) - self.formatter = output_formatter.OutputFormatter( + + FavoriteQueries.instance = FavoriteQueries.from_config(self.config) + + self.dsn_alias = None + self.formatter = TabularOutputFormatter( format_name=c['main']['table_format']) + sql_format.register_new_formatter(self.formatter) + self.formatter.mycli = self self.syntax_style = c['main']['syntax_style'] self.less_chatty = c['main'].as_bool('less_chatty') self.cli_style = c['colors'] + self.output_style = style_factory_output( + self.syntax_style, + self.cli_style + ) self.wider_completion_menu = c['main'].as_bool('wider_completion_menu') c_dest_warning = c['main'].as_bool('destructive_warning') self.destructive_warning = c_dest_warning if warn is None else warn @@ -122,15 +163,16 @@ def __init__(self, sqlexecute=None, prompt=None, c['main'].as_bool('auto_vertical_output') # Write user config if system config wasn't the last config loaded. - if c.filename not in self.system_config_files: - write_default_config(self.default_config_file, myclirc) + if c.filename not in self.system_config_files and not os.path.exists(myclirc): + write_default_config(myclirc) # audit log if self.logfile is None and 'audit_log' in c['main']: try: self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a') except (IOError, OSError) as e: - self.output('Error: Unable to open the audit log file. Your queries will not be logged.', err=True, fg='red') + self.echo('Error: Unable to open the audit log file. 
Your queries will not be logged.', + err=True, fg='red') self.logfile = False self.completion_refresher = CompletionRefresher() @@ -141,7 +183,8 @@ def __init__(self, sqlexecute=None, prompt=None, prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt'] self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \ self.default_prompt - self.prompt_continuation_format = c['main']['prompt_continuation'] + self.multiline_continuation_char = c['main']['prompt_continuation'] + keyword_casing = c['main'].get('keyword_casing', 'auto') self.query_history = [] @@ -149,7 +192,8 @@ def __init__(self, sqlexecute=None, prompt=None, self.smart_completion = c['main'].as_bool('smart_completion') self.completer = SQLCompleter( self.smart_completion, - supported_formats=self.formatter.supported_formats()) + supported_formats=self.formatter.supported_formats, + keyword_casing=keyword_casing) self._completer_lock = threading.Lock() # Register custom special commands. @@ -166,7 +210,7 @@ def __init__(self, sqlexecute=None, prompt=None, # There was an error reading the login path file. print('Error: Unable to read login path file.') - self.cli = None + self.prompt_app = None def register_special_commands(self): special.register_special_command(self.change_db, 'use', @@ -176,8 +220,10 @@ def register_special_commands(self): aliases=('\\r', ), case_sensitive=True) special.register_special_command(self.refresh_completions, 'rehash', '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',)) - special.register_special_command(self.change_table_format, 'tableformat', - '\\T', 'Change Table Type.', aliases=('\\T',), case_sensitive=True) + special.register_special_command( + self.change_table_format, 'tableformat', '\\T', + 'Change the table format used to output results.', + aliases=('\\T',), case_sensitive=True) special.register_special_command(self.execute_from_file, 'source', '\\. 
filename', 'Execute commands from file.', aliases=('\\.',)) special.register_special_command(self.change_prompt_format, 'prompt', @@ -185,21 +231,28 @@ def register_special_commands(self): def change_table_format(self, arg, **_): try: - self.formatter.set_format_name(arg) + self.formatter.format_name = arg yield (None, None, None, - 'Changed table type to {}'.format(arg)) + 'Changed table format to {}'.format(arg)) except ValueError: - msg = 'Table type {} not yet implemented. Allowed types:'.format( + msg = 'Table format {} not recognized. Allowed formats:'.format( arg) - for table_type in self.formatter.supported_formats(): + for table_type in self.formatter.supported_formats: msg += "\n\t{}".format(table_type) yield (None, None, None, msg) def change_db(self, arg, **_): - if arg is None: - self.sqlexecute.connect() - else: - self.sqlexecute.connect(database=arg) + if not arg: + click.secho( + "No database selected", + err=True, fg="red" + ) + return + + if arg.startswith('`') and arg.endswith('`'): + arg = re.sub(r'^`(.*)`$', r'\1', arg) + arg = re.sub(r'``', r'`', arg) + self.sqlexecute.change_db(arg) yield (None, None, None, 'You are now connected to database "%s" as ' 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user)) @@ -209,7 +262,7 @@ def execute_from_file(self, arg, **_): message = 'Missing required argument, filename.' 
return [(None, None, None, message)] try: - with open(os.path.expanduser(arg), encoding='utf-8') as f: + with open(os.path.expanduser(arg)) as f: query = f.read() except IOError as e: return [(None, None, None, str(e))] @@ -234,7 +287,7 @@ def change_prompt_format(self, arg, **_): def initialize_logging(self): - log_file = self.config['main']['log_file'] + log_file = os.path.expanduser(self.config['main']['log_file']) log_level = self.config['main']['log_level'] level_map = {'CRITICAL': logging.CRITICAL, @@ -247,10 +300,15 @@ def initialize_logging(self): # Disable logging if value is NONE by switching to a no-op handler # Set log level to a high value so it doesn't even waste cycles getting called. if log_level.upper() == "NONE": - handler = NullHandler() + handler = logging.NullHandler() log_level = "CRITICAL" + elif dir_path_exists(log_file): + handler = logging.FileHandler(log_file) else: - handler = logging.FileHandler(os.path.expanduser(log_file)) + self.echo( + 'Error: Unable to open the log file "{}".'.format(log_file), + err=True, fg='red') + return formatter = logging.Formatter( '%(asctime)s (%(process)d/%(threadName)s) ' @@ -267,11 +325,6 @@ def initialize_logging(self): root_logger.debug('Initializing mycli logging.') root_logger.debug('Log file %r.', log_file) - def connect_uri(self, uri, local_infile=None, ssl=None): - uri = urlparse(uri) - database = uri.path[1:] # ignore the leading fwd slash - self.connect(database, uri.username, uri.password, uri.hostname, - uri.port, local_infile=local_infile, ssl=ssl) def read_my_cnf_files(self, files, keys): """ @@ -280,23 +333,36 @@ def read_my_cnf_files(self, files, keys): :param keys: list of keys to retrieve :returns: tuple, with None for missing keys. 
""" - cnf = read_config_files(files) + cnf = read_config_files(files, list_values=False) + + sections = ['client', 'mysqld'] + key_transformations = { + 'mysqld': { + 'socket': 'default_socket', + 'port': 'default_port', + }, + } - sections = ['client'] if self.login_path and self.login_path != 'client': sections.append(self.login_path) if self.defaults_suffix: sections.extend([sect + self.defaults_suffix for sect in sections]) - def get(key): - result = None - for sect in cnf: - if sect in sections and key in cnf[sect]: - result = cnf[sect][key] - return result + configuration = defaultdict(lambda: None) + for key in keys: + for section in cnf: + if ( + section not in sections or + key not in cnf[section] + ): + continue + new_key = key_transformations.get(section, {}).get(key) or key + configuration[new_key] = strip_matching_quotes( + cnf[section][key]) + + return configuration - return {x: get(x) for x in keys} def merge_ssl_with_cnf(self, ssl, cnf): """Merge SSL configuration dict with cnf dict""" @@ -322,7 +388,9 @@ def merge_ssl_with_cnf(self, ssl, cnf): return merged def connect(self, database='', user='', passwd='', host='', port='', - socket='', charset='', local_infile='', ssl=''): + socket='', charset='', local_infile='', ssl='', + ssh_user='', ssh_host='', ssh_port='', + ssh_password='', ssh_key_filename='', init_command='', password_file=''): cnf = {'database': None, 'user': None, @@ -330,6 +398,7 @@ def connect(self, database='', user='', passwd='', host='', port='', 'host': None, 'port': None, 'socket': None, + 'default_socket': None, 'default-character-set': None, 'local-infile': None, 'loose-local-infile': None, @@ -343,25 +412,24 @@ def connect(self, database='', user='', passwd='', host='', port='', cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) # Fall back to config values only if user did not specify a value. 
- database = database or cnf['database'] - if port or host: - socket = '' - else: - socket = socket or cnf['socket'] user = user or cnf['user'] or os.getenv('USER') - host = host or cnf['host'] or 'localhost' - port = port or cnf['port'] or 3306 + host = host or cnf['host'] + port = port or cnf['port'] ssl = ssl or {} - try: - port = int(port) - except ValueError as e: - self.output("Error: Invalid port number: '{0}'.".format(port), - err=True, fg='red') - exit(1) + port = port and int(port) + if not port: + port = 3306 + if not host or host == 'localhost': + socket = ( + cnf['socket'] or + cnf['default_socket'] or + guess_socket_location() + ) + - passwd = passwd or cnf['password'] + passwd = passwd if isinstance(passwd, str) else cnf['password'] charset = charset or cnf['default-character-set'] or 'utf8' # Favor whichever local_infile option is set. @@ -378,60 +446,142 @@ def connect(self, database='', user='', passwd='', host='', port='', if not any(v for v in ssl.values()): ssl = None + # if the passwd is not specfied try to set it using the password_file option + password_from_file = self.get_password_from_file(password_file) + passwd = passwd or password_from_file + # Connect to the database. 
- try: + def _connect(): try: - sqlexecute = SQLExecute(database, user, passwd, host, port, - socket, charset, local_infile, ssl) + self.sqlexecute = SQLExecute( + database, user, passwd, host, port, socket, charset, + local_infile, ssl, ssh_user, ssh_host, ssh_port, + ssh_password, ssh_key_filename, init_command + ) except OperationalError as e: - if ('Access denied for user' in e.args[1]): - passwd = click.prompt('Password', hide_input=True, - show_default=False, type=str) - sqlexecute = SQLExecute(database, user, passwd, host, port, - socket, charset, local_infile, ssl) + if e.args[0] == ERROR_CODE_ACCESS_DENIED: + if password_from_file: + new_passwd = password_from_file + else: + new_passwd = click.prompt('Password', hide_input=True, + show_default=False, type=str, err=True) + self.sqlexecute = SQLExecute( + database, user, new_passwd, host, port, socket, + charset, local_infile, ssl, ssh_user, ssh_host, + ssh_port, ssh_password, ssh_key_filename, init_command + ) else: raise e + + try: + if not WIN and socket: + socket_owner = getpwuid(os.stat(socket).st_uid).pw_name + self.echo( + f"Connecting to socket {socket}, owned by user {socket_owner}", err=True) + try: + _connect() + except OperationalError as e: + # These are "Can't open socket" and 2x "Can't connect" + if [code for code in (2001, 2002, 2003) if code == e.args[0]]: + self.logger.debug('Database connection failed: %r.', e) + self.logger.error( + "traceback: %r", traceback.format_exc()) + self.logger.debug('Retrying over TCP/IP') + self.echo( + "Failed to connect to local MySQL server through socket '{}':".format(socket)) + self.echo(str(e), err=True) + self.echo( + 'Retrying over TCP/IP', err=True) + + # Else fall back to TCP/IP localhost + socket = "" + host = 'localhost' + port = 3306 + _connect() + else: + raise e + else: + host = host or 'localhost' + port = port or 3306 + + # Bad ports give particularly daft error messages + try: + port = int(port) + except ValueError as e: + self.echo("Error: 
Invalid port number: '{0}'.".format(port), + err=True, fg='red') + exit(1) + + _connect() except Exception as e: # Connecting to a database could fail. self.logger.debug('Database connection failed: %r.', e) self.logger.error("traceback: %r", traceback.format_exc()) - self.output(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg='red') exit(1) - self.sqlexecute = sqlexecute + def get_password_from_file(self, password_file): + password_from_file = None + if password_file: + if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \ + and os.access(password_file, os.R_OK): + with open(password_file) as fp: + password_from_file = fp.readline() + password_from_file = password_from_file.rstrip().lstrip() + + return password_from_file + + def handle_editor_command(self, text): + r"""Editor command is any query that is prefixed or suffixed by a '\e'. + The reason for a while loop is because a user might edit a query + multiple times. For eg: - def handle_editor_command(self, cli, document): - """ - Editor command is any query that is prefixed or suffixed - by a '\e'. The reason for a while loop is because a user - might edit a query multiple times. - For eg: "select * from \e" to edit it in vim, then come back to the prompt with the edited query "select * from blah where q = 'abc'\e" to edit it again. - :param cli: CommandLineInterface - :param document: Document + :param text: Document :return: Document + """ - # FIXME: using application.pre_run_callables like this here is not the best solution. - # It's internal api of prompt_toolkit that may change. This was added to fix - # https://github.com/dbcli/pgcli/issues/668. We may find a better way to do it in the future. 
- saved_callables = cli.application.pre_run_callables - while special.editor_command(document.text): - filename = special.get_filename(document.text) - sql, message = special.open_external_editor(filename, - sql=document.text) + + while special.editor_command(text): + filename = special.get_filename(text) + query = (special.get_editor_query(text) or + self.get_last_query()) + sql, message = special.open_external_editor(filename, sql=query) if message: # Something went wrong. Raise an exception and bail. raise RuntimeError(message) - cli.current_buffer.document = Document(sql, cursor_position=len(sql)) - cli.application.pre_run_callables = [] - document = cli.run() + while True: + try: + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + continue - cli.application.pre_run_callables = saved_callables - return document + return text + + def handle_clip_command(self, text): + r"""A clip command is any query that is prefixed or suffixed by a + '\clip'. + + :param text: Document + :return: Boolean + + """ + + if special.clip_command(text): + query = (special.get_clip_query(text) or + self.get_last_query()) + message = special.copy_query_to_clipboard(sql=query) + if message: + raise RuntimeError(message) + return True + return False def run_cli(self): + iterations = 0 sqlexecute = self.sqlexecute logger = self.logger self.configure_pager() @@ -439,59 +589,87 @@ def run_cli(self): if self.smart_completion: self.refresh_completions() - project_root = os.path.dirname(PACKAGE_ROOT) - author_file = os.path.join(project_root, 'AUTHORS') - sponsor_file = os.path.join(project_root, 'SPONSORS') + history_file = os.path.expanduser( + os.environ.get('MYCLI_HISTFILE', '~/.mycli-history')) + if dir_path_exists(history_file): + history = FileHistory(history_file) + else: + history = None + self.echo( + 'Error: Unable to open the history file "{}". 
' + 'Your query history will not be saved.'.format(history_file), + err=True, fg='red') - key_binding_manager = mycli_bindings() + key_bindings = mycli_bindings(self) if not self.less_chatty: - print('Version:', __version__) - print('Chat: https://gitter.im/dbcli/mycli') - print('Mail: https://groups.google.com/forum/#!forum/mycli-users') - print('Home: http://mycli.net') - print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file])) + print(sqlexecute.server_info) + print('mycli', __version__) + print(SUPPORT_INFO) + print('Thanks to the contributor -', thanks_picker()) - def prompt_tokens(cli): + def get_message(): prompt = self.get_prompt(self.prompt_format) if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: prompt = self.get_prompt('\\d> ') - return [(Token.Prompt, prompt)] + prompt = prompt.replace("\\x1b", "\x1b") + return ANSI(prompt) + + def get_continuation(width, *_): + if self.multiline_continuation_char == '': + continuation = '' + elif self.multiline_continuation_char: + left_padding = width - len(self.multiline_continuation_char) + continuation = " " * \ + max((left_padding - 1), 0) + \ + self.multiline_continuation_char + " " + else: + continuation = " " + return [('class:continuation', continuation)] - def get_continuation_tokens(cli, width): - continuation_prompt = self.get_prompt(self.prompt_continuation_format) - return [(Token.Continuation, ' ' * (width - len(continuation_prompt)) + continuation_prompt)] + def show_suggestion_tip(): + return iterations < 2 - def one_iteration(document=None): - if document is None: - document = self.cli.run() + def one_iteration(text=None): + if text is None: + try: + text = self.prompt_app.prompt() + except KeyboardInterrupt: + return special.set_expanded_output(False) - # The reason we check here instead of inside the sqlexecute is - # because we want to raise the Exit exception which will be - # caught by the try/except block that wraps the - # 
sqlexecute.run() statement. - if quit_command(document.text): - raise EOFError + try: + text = self.handle_editor_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg='red') + return try: - document = self.handle_editor_command(self.cli, document) + if self.handle_clip_command(text): + return except RuntimeError as e: - logger.error("sql: %r, error: %r", document.text, e) + logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self.output(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg='red') return + if not text.strip(): + return + if self.destructive_warning: - destroy = confirm_destructive_query(document.text) + destroy = confirm_destructive_query(text) if destroy is None: pass # Query was not destructive. Nothing to do here. elif destroy is True: - self.output('Your call!') + self.echo('Your call!') else: - self.output('Wise choice!') + self.echo('Wise choice!') return + else: + destroy = True # Keep track of whether or not the query is mutating. 
In case # of a multi-statement query, the overall query is considered @@ -499,20 +677,20 @@ def one_iteration(document=None): mutating = False try: - logger.debug('sql: %r', document.text) + logger.debug('sql: %r', text) - special.write_tee(self.get_prompt(self.prompt_format) + document.text) + special.write_tee(self.get_prompt(self.prompt_format) + text) if self.logfile: self.logfile.write('\n# %s\n' % datetime.now()) - self.logfile.write(document.text) + self.logfile.write(text) self.logfile.write('\n') successful = False start = time() - res = sqlexecute.run(document.text) + res = sqlexecute.run(text) + self.formatter.query = text successful = True - output = [] - total = 0 + result_count = 0 for title, cur, headers, status in res: logger.debug("headers: %r", headers) logger.debug("rows: %r", cur) @@ -520,25 +698,41 @@ def one_iteration(document=None): threshold = 1000 if (is_select(status) and cur and cur.rowcount > threshold): - self.output('The result set has more than %s rows.' - % threshold, fg='red') - if not click.confirm('Do you want to continue?'): - self.output("Aborted!", err=True, fg='red') + self.echo('The result set has more than {} rows.'.format( + threshold), fg='red') + if not confirm('Do you want to continue?'): + self.echo("Aborted!", err=True, fg='red') break if self.auto_vertical_output: - max_width = self.cli.output.get_size().columns + max_width = self.prompt_app.output.get_size().columns else: max_width = None - formatted = self.format_output(title, cur, headers, status, - special.is_expanded_output(), - max_width) + formatted = self.format_output( + title, cur, headers, special.is_expanded_output(), + max_width) - output.extend(formatted) - end = time() - total += end - start - mutating = mutating or is_mutating(status) + t = time() - start + try: + if result_count > 0: + self.echo('') + try: + self.output(formatted, status) + except KeyboardInterrupt: + pass + if special.is_timing_enabled(): + self.echo('Time: %0.03fs' % t) + except 
KeyboardInterrupt: + pass + + start = time() + result_count += 1 + mutating = mutating or destroy or is_mutating(status) + special.unset_once_if_written() + special.unset_pipe_once_if_written() + except EOFError as e: + raise e except KeyboardInterrupt: # get last connection id connection_id_to_kill = sqlexecute.connection_id @@ -550,110 +744,185 @@ def one_iteration(document=None): status_str = str(status).lower() if status_str.find('ok') > -1: logger.debug("cancelled query, connection id: %r, sql: %r", - connection_id_to_kill, document.text) - self.output("cancelled query", err=True, fg='red') + connection_id_to_kill, text) + self.echo("cancelled query", err=True, fg='red') except Exception as e: - self.output('Encountered error while cancelling query: %s' % str(e), err=True, fg='red') + self.echo('Encountered error while cancelling query: {}'.format(e), + err=True, fg='red') except NotImplementedError: - self.output('Not Yet Implemented.', fg="yellow") + self.echo('Not Yet Implemented.', fg="yellow") except OperationalError as e: logger.debug("Exception: %r", e) if (e.args[0] in (2003, 2006, 2013)): logger.debug('Attempting to reconnect.') - self.output('Reconnecting...', fg='yellow') + self.echo('Reconnecting...', fg='yellow') try: sqlexecute.connect() logger.debug('Reconnected successfully.') - one_iteration(document) + one_iteration(text) return # OK to just return, cuz the recursion call runs to the end. except OperationalError as e: logger.debug('Reconnect failed. e: %r', e) - self.output(str(e), err=True, fg='red') - return # If reconnection failed, don't proceed further. + self.echo(str(e), err=True, fg='red') + # If reconnection failed, don't proceed further. 
+ return else: - logger.error("sql: %r, error: %r", document.text, e) + logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self.output(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg='red') except Exception as e: - logger.error("sql: %r, error: %r", document.text, e) + logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self.output(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg='red') else: - try: - special.write_tee('\n'.join(output)) - if special.is_pager_enabled(): - self.output_via_pager('\n'.join(output)) - else: - self.output('\n'.join(output)) - except KeyboardInterrupt: - pass - if special.is_timing_enabled(): - self.output('Time: %0.03fs' % total) + if is_dropping_database(text, self.sqlexecute.dbname): + self.sqlexecute.dbname = None + self.sqlexecute.connect() # Refresh the table names and column names if necessary. - if need_completion_refresh(document.text): + if need_completion_refresh(text): self.refresh_completions( - reset=need_completion_reset(document.text)) + reset=need_completion_reset(text)) finally: if self.logfile is False: - self.output("Warning: This query was not logged.", err=True, fg='red') - query = Query(document.text, successful, mutating) + self.echo("Warning: This query was not logged.", + err=True, fg='red') + query = Query(text, successful, mutating) self.query_history.append(query) + get_toolbar_tokens = create_toolbar_tokens_func( + self, show_suggestion_tip) + if self.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN - get_toolbar_tokens = create_toolbar_tokens_func(self.completion_refresher.is_refreshing) - - layout = create_prompt_layout(lexer=MyCliLexer, - multiline=True, - get_prompt_tokens=prompt_tokens, - get_continuation_tokens=get_continuation_tokens, - get_bottom_toolbar_tokens=get_toolbar_tokens, - 
display_completions_in_columns=self.wider_completion_menu, - extra_input_processors=[ - ConditionalProcessor( - processor=HighlightMatchingBracketProcessor(chars='[](){}'), - filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()), - ]) with self._completer_lock: - buf = CLIBuffer(always_multiline=self.multi_line, completer=self.completer, - history=FileHistory(os.path.expanduser(os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))), - complete_while_typing=Always(), accept_action=AcceptAction.RETURN_DOCUMENT) if self.key_bindings == 'vi': editing_mode = EditingMode.VI else: editing_mode = EditingMode.EMACS - application = Application(style=style_factory(self.syntax_style, self.cli_style), - layout=layout, buffer=buf, - key_bindings_registry=key_binding_manager.registry, - on_exit=AbortAction.RAISE_EXCEPTION, - on_abort=AbortAction.RETRY, - editing_mode=editing_mode, - ignore_case=True) - self.cli = CommandLineInterface(application=application, - eventloop=create_eventloop()) + self.prompt_app = PromptSession( + lexer=PygmentsLexer(MyCliLexer), + reserve_space_for_menu=self.get_reserved_space(), + message=get_message, + prompt_continuation=get_continuation, + bottom_toolbar=get_toolbar_tokens, + complete_style=complete_style, + input_processors=[ConditionalProcessor( + processor=HighlightMatchingBracketProcessor( + chars='[](){}'), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() + )], + tempfile_suffix='.sql', + completer=DynamicCompleter(lambda: self.completer), + history=history, + auto_suggest=AutoSuggestFromHistory(), + complete_while_typing=True, + multiline=cli_is_multiline(self), + style=style_factory(self.syntax_style, self.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=editing_mode, + search_ignore_case=True + ) try: while True: one_iteration() + iterations += 1 except EOFError: special.close_tee() if not self.less_chatty: - 
self.output('Goodbye!') + self.echo('Goodbye!') - def output(self, text, **kwargs): - special.write_tee(text) + def log_output(self, output): + """Log the output in the audit log, if it's enabled.""" if self.logfile: - self.logfile.write(utf8tounicode(text)) - self.logfile.write('\n') - click.secho(text, **kwargs) + click.echo(output, file=self.logfile) - def output_via_pager(self, text): - if self.logfile: - self.logfile.write(text) - self.logfile.write('\n') - click.echo_via_pager(text) + def echo(self, s, **kwargs): + """Print a message to stdout. + + The message will be logged in the audit log, if enabled. + + All keyword arguments are passed to click.echo(). + + """ + self.log_output(s) + click.secho(s, **kwargs) + + def get_output_margin(self, status=None): + """Get the output margin (number of rows for the prompt, footer and + timing message.""" + margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1 + if special.is_timing_enabled(): + margin += 1 + if status: + margin += 1 + status.count('\n') + + return margin + + + def output(self, output, status=None): + """Output text to stdout or a pager command. + + The status text is not outputted to pager or files. + + The message will be logged in the audit log, if enabled. The + message will be written to the tee file, if enabled. The + message will be written to the output file, if enabled. 
+ + """ + if output: + size = self.prompt_app.output.get_size() + + margin = self.get_output_margin(status) + + fits = True + buf = [] + output_via_pager = self.explicit_pager and special.is_pager_enabled() + for i, line in enumerate(output, 1): + self.log_output(line) + special.write_tee(line) + special.write_once(line) + special.write_pipe_once(line) + + if fits or output_via_pager: + # buffering + buf.append(line) + if len(line) > size.columns or i > (size.rows - margin): + fits = False + if not self.explicit_pager and special.is_pager_enabled(): + # doesn't fit, use pager + output_via_pager = True + + if not output_via_pager: + # doesn't fit, flush buffer + for buf_line in buf: + click.secho(buf_line) + buf = [] + else: + click.secho(line) + + if buf: + if output_via_pager: + def newlinewrapper(text): + for line in text: + yield line + "\n" + click.echo_via_pager(newlinewrapper(buf)) + else: + for line in buf: + click.secho(line) + + if status: + self.log_output(status) + click.secho(status) def configure_pager(self): # Provide sane defaults for less if they are empty. 
@@ -663,7 +932,11 @@ def configure_pager(self): cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager']) if cnf['pager']: special.set_pager(cnf['pager']) - if cnf['skip-pager']: + self.explicit_pager = True + else: + self.explicit_pager = False + + if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'): special.disable_pager() def refresh_completions(self, reset=False): @@ -673,29 +946,22 @@ def refresh_completions(self, reset=False): self.completion_refresher.refresh( self.sqlexecute, self._on_completions_refreshed, {'smart_completion': self.smart_completion, - 'supported_formats': self.formatter.supported_formats()}) + 'supported_formats': self.formatter.supported_formats, + 'keyword_casing': self.completer.keyword_casing}) return [(None, None, None, 'Auto-completion refresh started in the background.')] def _on_completions_refreshed(self, new_completer): - self._swap_completer_objects(new_completer) - - if self.cli: - # After refreshing, redraw the CLI to clear the statusbar - # "Refreshing completions..." indicator - self.cli.request_redraw() - - def _swap_completer_objects(self, new_completer): """Swap the completer object in cli with the newly created completer. """ with self._completer_lock: self.completer = new_completer - # When mycli is first launched we call refresh_completions before - # instantiating the cli object. So it is necessary to check if cli - # exists before trying the replace the completer object in cli. - if self.cli: - self.cli.current_buffer.completer = new_completer + + if self.prompt_app: + # After refreshing, redraw the CLI to clear the statusbar + # "Refreshing completions..." 
indicator + self.prompt_app.app.invalidate() def get_completions(self, text, cursor_positition): with self._completer_lock: @@ -705,13 +971,20 @@ def get_completions(self, text, cursor_positition): def get_prompt(self, string): sqlexecute = self.sqlexecute host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host + now = datetime.now() string = string.replace('\\u', sqlexecute.user or '(none)') string = string.replace('\\h', host or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)') - string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli') + string = string.replace('\\t', sqlexecute.server_info.species.name) string = string.replace('\\n', "\n") - string = string.replace('\\D', datetime.now().strftime('%a %b %d %H:%M:%S %Y')) + string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) + string = string.replace('\\m', now.strftime('%M')) + string = string.replace('\\P', now.strftime('%p')) + string = string.replace('\\R', now.strftime('%H')) + string = string.replace('\\r', now.strftime('%I')) + string = string.replace('\\s', now.strftime('%S')) string = string.replace('\\p', str(sqlexecute.port)) + string = string.replace('\\A', self.dsn_alias or '(none)') string = string.replace('\\_', ' ') return string @@ -720,71 +993,126 @@ def run_query(self, query, new_line=True): results = self.sqlexecute.run(query) for result in results: title, cur, headers, status = result - output = self.format_output(title, cur, headers, None) + self.formatter.query = query + output = self.format_output(title, cur, headers) for line in output: click.echo(line, nl=new_line) - def format_output(self, title, cur, headers, status, expanded=False, + def format_output(self, title, cur, headers, expanded=False, max_width=None): - expanded = expanded or self.formatter.get_format_name() == 'expanded' + expanded = expanded or self.formatter.format_name == 'vertical' output = [] + output_kwargs = { + 'dialect': 'unix', + 
'disable_numparse': True, + 'preserve_whitespace': True, + 'style': self.output_style + } + + if not self.formatter.format_name in sql_format.supported_formats: + output_kwargs["preprocessors"] = (preprocessors.align_decimals, ) + if title: # Only print the title if it's not None. - output.append(title) + output = itertools.chain(output, [title]) if cur: - rows = list(cur) - formatted = self.formatter.format_output( - rows, headers, format_name='expanded' if expanded else None) + column_types = None + if hasattr(cur, 'description'): + def get_col_type(col): + col_type = FIELD_TYPES.get(col[1], str) + return col_type if type(col_type) is type else str + column_types = [get_col_type(col) for col in cur.description] - if (not expanded and max_width and rows and - content_exceeds_width(rows[0], max_width) and headers): - formatted = self.formatter.format_output( - rows, headers, format_name='expanded') + if max_width is not None: + cur = list(cur) - output.append(formatted) + formatted = self.formatter.format_output( + cur, headers, format_name='vertical' if expanded else None, + column_types=column_types, + **output_kwargs) + + if isinstance(formatted, str): + formatted = formatted.splitlines() + formatted = iter(formatted) + + if (not expanded and max_width and headers and cur): + first_line = next(formatted) + if len(strip_ansi(first_line)) > max_width: + formatted = self.formatter.format_output( + cur, headers, format_name='vertical', column_types=column_types, **output_kwargs) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + else: + formatted = itertools.chain([first_line], formatted) + + output = itertools.chain(output, formatted) - if status: # Only print the status if it's not None. 
- output.append(status) return output + def get_reserved_space(self): + """Get the number of lines to reserve for the completion menu.""" + reserved_space_ratio = .45 + max_reserved_space = 8 + _, height = shutil.get_terminal_size() + return min(int(round(height * reserved_space_ratio)), max_reserved_space) + + def get_last_query(self): + """Get the last query executed or None.""" + return self.query_history[-1][0] if self.query_history else None + @click.command() @click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.') @click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors ' - '$MYSQL_TCP_PORT') + '$MYSQL_TCP_PORT.') @click.option('-u', '--user', help='User name to connect to the database.') @click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.') @click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str, - help='Password to connect to the database') + help='Password to connect to the database.') @click.option('--pass', 'password', envvar='MYSQL_PWD', type=str, - help='Password to connect to the database') -@click.option('--ssl-ca', help='CA file in PEM format', + help='Password to connect to the database.') +@click.option('--ssh-user', help='User name to connect to ssh server.') +@click.option('--ssh-host', help='Host name to connect to ssh server.') +@click.option('--ssh-port', default=22, help='Port to connect to ssh server.') +@click.option('--ssh-password', help='Password to connect to ssh server.') +@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.') +@click.option('--ssh-config-path', help='Path to ssh configuration.', + default=os.path.expanduser('~') + '/.ssh/config') +@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.') +@click.option('--ssl-ca', help='CA file in PEM format.', 
type=click.Path(exists=True)) -@click.option('--ssl-capath', help='CA directory') -@click.option('--ssl-cert', help='X509 cert in PEM format', +@click.option('--ssl-capath', help='CA directory.') +@click.option('--ssl-cert', help='X509 cert in PEM format.', type=click.Path(exists=True)) -@click.option('--ssl-key', help='X509 key in PEM format', +@click.option('--ssl-key', help='X509 key in PEM format.', type=click.Path(exists=True)) -@click.option('--ssl-cipher', help='SSL cipher to use') +@click.option('--ssl-cipher', help='SSL cipher to use.') @click.option('--ssl-verify-server-cert', is_flag=True, help=('Verify server\'s "Common Name" in its cert against ' 'hostname used when connecting. This option is disabled ' - 'by default')) + 'by default.')) # as of 2016-02-15 revocation list is not supported by underling PyMySQL # library (--ssl-crl and --ssl-crlpath options in vanilla mysql client) -@click.option('-v', '--version', is_flag=True, help='Version of mycli.') +@click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.') +@click.option('-v', '--verbose', is_flag=True, help='Verbose output.') @click.option('-D', '--database', 'dbname', help='Database to use.') +@click.option('-d', '--dsn', default='', envvar='DSN', + help='Use DSN configured into the [alias_dsn] section of myclirc file.') +@click.option('--list-dsn', 'list_dsn', is_flag=True, + help='list of DSN configured into the [alias_dsn] section of myclirc file.') +@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True, + help='list ssh configurations in the ssh config (requires paramiko).') @click.option('-R', '--prompt', 'prompt', - help='Prompt format (Default: "{0}")'.format( + help='Prompt format (Default: "{0}").'.format( MyCli.default_prompt)) @click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'), help='Log every query and its results to a file.') @click.option('--defaults-group-suffix', type=str, - help='Read config group with the specified 
suffix.') + help='Read MySQL config groups with the specified suffix.') @click.option('--defaults-file', type=click.Path(), - help='Only read default options from the given file') + help='Only read MySQL options from the given file.') @click.option('--myclirc', type=click.Path(), default="~/.myclirc", help='Location of myclirc file.') @click.option('--auto-vertical-output', is_flag=True, @@ -797,16 +1125,34 @@ def format_output(self, title, cur, headers, status, expanded=False, help='Warn before running a destructive query.') @click.option('--local-infile', type=bool, help='Enable/disable LOAD DATA LOCAL INFILE.') -@click.option('--login-path', type=str, +@click.option('-g', '--login-path', type=str, help='Read this path from the login file.') @click.option('-e', '--execute', type=str, - help='Execute query to the database.') + help='Execute command and quit.') +@click.option('--init-command', type=str, + help='SQL statement to execute after connecting.') +@click.option('--charset', type=str, + help='Character set for MySQL session.') +@click.option('--password-file', type=click.Path(), + help='File or FIFO path containing the password to connect to the db if not specified otherwise.') @click.argument('database', default='', nargs=1) def cli(database, user, host, port, socket, password, dbname, - version, prompt, logfile, defaults_group_suffix, defaults_file, - login_path, auto_vertical_output, local_infile, ssl_ca, ssl_capath, - ssl_cert, ssl_key, ssl_cipher, ssl_verify_server_cert, table, csv, - warn, execute, myclirc): + version, verbose, prompt, logfile, defaults_group_suffix, + defaults_file, login_path, auto_vertical_output, local_infile, + ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher, + ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn, + list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, + ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host, + init_command, charset, password_file): + """A MySQL terminal client with 
auto-completion and syntax highlighting. + + \b + Examples: + - mycli my_database + - mycli -u my_user -h my_host.com my_database + - mycli mysql://my_user@my_host.com:3306/my_database + + """ if version: print('Version:', __version__) @@ -817,9 +1163,35 @@ def cli(database, user, host, port, socket, password, dbname, defaults_file=defaults_file, login_path=login_path, auto_vertical_output=auto_vertical_output, warn=warn, myclirc=myclirc) - + if list_dsn: + try: + alias_dsn = mycli.config['alias_dsn'] + except KeyError as err: + click.secho('Invalid DSNs found in the config file. '\ + 'Please check the "[alias_dsn]" section in myclirc.', + err=True, fg='red') + exit(1) + except Exception as e: + click.secho(str(e), err=True, fg='red') + exit(1) + for alias, value in alias_dsn.items(): + if verbose: + click.secho("{} : {}".format(alias, value)) + else: + click.secho(alias) + sys.exit(0) + if list_ssh_config: + ssh_config = read_ssh_config(ssh_config_path) + for host in ssh_config.get_hostnames(): + if verbose: + host_config = ssh_config.lookup(host) + click.secho("{} : {}".format( + host, host_config.get('hostname'))) + else: + click.secho(host) + sys.exit(0) # Choose which ever one has a valid value. - database = database or dbname + database = dbname or database ssl = { 'ca': ssl_ca and os.path.expanduser(ssl_ca), @@ -832,11 +1204,74 @@ def cli(database, user, host, port, socket, password, dbname, # remove empty ssl options ssl = {k: v for k, v in ssl.items() if v is not None} + + dsn_uri = None + + # Treat the database argument as a DSN alias if we're missing + # other connection information. 
+ if (mycli.config['alias_dsn'] and database and '://' not in database + and not any([user, password, host, port, login_path])): + dsn, database = database, '' + if database and '://' in database: - mycli.connect_uri(database, local_infile, ssl) - else: - mycli.connect(database, user, password, host, port, socket, - local_infile=local_infile, ssl=ssl) + dsn_uri, database = database, '' + + if dsn: + try: + dsn_uri = mycli.config['alias_dsn'][dsn] + except KeyError: + click.secho('Could not find the specified DSN in the config file. ' + 'Please check the "[alias_dsn]" section in your ' + 'myclirc.', err=True, fg='red') + exit(1) + else: + mycli.dsn_alias = dsn + + if dsn_uri: + uri = urlparse(dsn_uri) + if not database: + database = uri.path[1:] # ignore the leading fwd slash + if not user: + user = unquote(uri.username) + if not password and uri.password is not None: + password = unquote(uri.password) + if not host: + host = uri.hostname + if not port: + port = uri.port + + if ssh_config_host: + ssh_config = read_ssh_config( + ssh_config_path + ).lookup(ssh_config_host) + ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') + ssh_user = ssh_user if ssh_user else ssh_config.get('user') + if ssh_config.get('port') and ssh_port == 22: + # port has a default value, overwrite it if it's in the config + ssh_port = int(ssh_config.get('port')) + ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( + 'identityfile', [None])[0] + + ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) + + mycli.connect( + database=database, + user=user, + passwd=password, + host=host, + port=port, + socket=socket, + local_infile=local_infile, + ssl=ssl, + ssh_user=ssh_user, + ssh_host=ssh_host, + ssh_port=ssh_port, + ssh_password=ssh_password, + ssh_key_filename=ssh_key_filename, + init_command=init_command, + charset=charset, + password_file=password_file + ) mycli.logger.debug('Launch Params: \n' '\tdatabase: %r' @@ -848,9 +1283,9 
@@ def cli(database, user, host, port, socket, password, dbname, if execute: try: if csv: - mycli.formatter.set_format_name('csv') + mycli.formatter.format_name = 'csv' elif not table: - mycli.formatter.set_format_name('tsv') + mycli.formatter.format_name = 'tsv' mycli.run_query(execute) exit(0) @@ -862,24 +1297,30 @@ def cli(database, user, host, port, socket, password, dbname, mycli.run_cli() else: stdin = click.get_text_stream('stdin') - stdin_text = stdin.read() - try: - sys.stdin = open('/dev/tty') - except FileNotFoundError: - mycli.logger.warning('Unable to open TTY as stdin.') + stdin_text = stdin.read() + except MemoryError: + click.secho('Failed! Ran out of memory.', err=True, fg='red') + click.secho('You might want to try the official mysql client.', err=True, fg='red') + click.secho('Sorry... :(', err=True, fg='red') + exit(1) + + if mycli.destructive_warning and is_destructive(stdin_text): + try: + sys.stdin = open('/dev/tty') + warn_confirmed = confirm_destructive_query(stdin_text) + except (IOError, OSError): + mycli.logger.warning('Unable to open TTY as stdin.') + if not warn_confirmed: + exit(0) - if (mycli.destructive_warning and - confirm_destructive_query(stdin_text) is False): - exit(0) try: new_line = True if csv: - mycli.formatter.set_format_name('csv') - new_line = False + mycli.formatter.format_name = 'csv' elif not table: - mycli.formatter.set_format_name('tsv') + mycli.formatter.format_name = 'tsv' mycli.run_query(stdin_text, new_line=new_line) exit(0) @@ -888,13 +1329,6 @@ def cli(database, user, host, port, socket, password, dbname, exit(1) -def content_exceeds_width(row, width): - # Account for 3 characters between each column - separator_space = (len(row)*3) - # Add 2 columns for a bit of buffer - line_len = sum([len(str(x)) for x in row]) + separator_space + 2 - return line_len > width - def need_completion_refresh(queries): """Determines if the completion needs a refresh by checking if the sql statement is an alter, create, drop or 
change db.""" @@ -902,11 +1336,12 @@ def need_completion_refresh(queries): try: first_token = query.split()[0] if first_token.lower() in ('alter', 'create', 'use', '\\r', - '\\u', 'connect', 'drop'): + '\\u', 'connect', 'drop', 'rename'): return True except Exception: return False + def need_completion_reset(queries): """Determines if the statement is a database switch such as 'use' or '\\u'. When a database is changed the existing completions must be reset before we @@ -927,56 +1362,59 @@ def is_mutating(status): return False mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop', - 'replace', 'truncate', 'load']) + 'replace', 'truncate', 'load', 'rename']) return status.split(None, 1)[0].lower() in mutating + def is_select(status): """Returns true if the first word in status is 'select'.""" if not status: return False return status.split(None, 1)[0].lower() == 'select' -def query_starts_with(query, prefixes): - """Check if the query starts with any item from *prefixes*.""" - prefixes = [prefix.lower() for prefix in prefixes] - formatted_sql = sqlparse.format(query.lower(), strip_comments=True) - return bool(formatted_sql) and formatted_sql.split()[0] in prefixes -def queries_start_with(queries, prefixes): - """Check if any queries start with any item from *prefixes*.""" - for query in sqlparse.split(queries): - if query and query_starts_with(query, prefixes) is True: - return True - return False - -def is_destructive(queries): - keywords = ('drop', 'shutdown', 'delete', 'truncate') - return queries_start_with(queries, keywords) - -def confirm_destructive_query(queries): - """Check if the query is destructive and prompts the user to confirm. - Returns: - None if the query is non-destructive or we can't prompt the user. - True if the query is destructive and the user wants to proceed. - False if the query is destructive and the user doesn't want to proceed. 
- """ - prompt_text = ("You're about to run a destructive command.\n" - "Do you want to proceed? (y/n)") - if is_destructive(queries) and sys.stdin.isatty(): - return click.prompt(prompt_text, type=bool) - -def quit_command(sql): - return (sql.strip().lower() == 'exit' - or sql.strip().lower() == 'quit' - or sql.strip() == '\q' - or sql.strip() == ':q') - -def thanks_picker(files=()): - for filename in files: - with open(filename, encoding='utf-8') as f: - contents = f.readlines() - - return choice([x.split('*')[1].strip() for x in contents if x.startswith('*')]) +def thanks_picker(): + import mycli + lines = ( + resources.read_text(mycli, 'AUTHORS') + + resources.read_text(mycli, 'SPONSORS') + ).split('\n') + + contents = [] + for line in lines: + m = re.match(r'^ *\* (.*)', line) + if m: + contents.append(m.group(1)) + return choice(contents) + + +@prompt_register('edit-and-execute-command') +def edit_and_execute(event): + """Different from the prompt-toolkit default, we want to have a choice not + to execute a query after editing, hence validate_and_handle=False.""" + buff = event.current_buffer + buff.open_in_editor(validate_and_handle=False) + + +def read_ssh_config(ssh_config_path): + ssh_config = paramiko.config.SSHConfig() + try: + with open(ssh_config_path) as f: + ssh_config.parse(f) + except FileNotFoundError as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + # Paramiko prior to version 2.7 raises Exception on parse errors. + # In 2.7 it has become paramiko.ssh_exception.SSHException, + # but let's catch everything for compatibility + except Exception as err: + click.secho( + f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ', + err=True, fg='red' + ) + sys.exit(1) + else: + return ssh_config if __name__ == "__main__": diff --git a/mycli/myclirc b/mycli/myclirc index 01a11426..c89caa05 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -32,7 +32,7 @@ timing = True # Table format. 
Possible values: ascii, double, github, # psql, plain, simple, grid, fancy_grid, pipe, orgtbl, rst, mediawiki, html, -# latex, latex_booktabs, textile, moinmoin, jira, expanded, tsv, csv. +# latex, latex_booktabs, textile, moinmoin, jira, vertical, tsv, csv. # Recommended: ascii table_format = ascii @@ -41,6 +41,7 @@ table_format = ascii # friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, # fruity. # Screenshots at http://mycli.net/syntax +# Can be further modified in [colors] syntax_style = default # Keybindings: Possible values: emacs, vi. @@ -52,13 +53,22 @@ key_bindings = emacs wider_completion_menu = False # MySQL prompt -# \t - Product type (Percona, MySQL, Mariadb) -# \u - Username -# \h - Hostname of the server +# \D - The full current date # \d - Database name +# \h - Hostname of the server +# \m - Minutes of the current time # \n - Newline +# \P - AM/PM +# \p - Port +# \R - The current time, in 24-hour military time (0-23) +# \r - The current time, standard 12-hour time (1-12) +# \s - Seconds of the current time +# \t - Product type (Percona, MySQL, MariaDB) +# \A - DSN alias name (from the [alias_dsn] section) +# \u - Username +# \x1b[...m - insert ANSI escape sequence prompt = '\t \u@\h:\d> ' -prompt_continuation = '-> ' +prompt_continuation = '->' # Skip intro info on startup and outro info on exit less_chatty = False @@ -70,33 +80,74 @@ login_path_as_host = False # and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.) auto_vertical_output = False +# keyword casing preference. Possible values "lower", "upper", "auto" +keyword_casing = auto + +# disabled pager on startup +enable_pager = True + # Custom colors for the completion menu, toolbar, etc. [colors] -# Completion menus. 
-Token.Menu.Completions.Completion.Current = 'bg:#00aaaa #000000' -Token.Menu.Completions.Completion = 'bg:#008888 #ffffff' -Token.Menu.Completions.MultiColumnMeta = 'bg:#aaffff #000000' -Token.Menu.Completions.ProgressButton = 'bg:#003333' -Token.Menu.Completions.ProgressBar = 'bg:#00aaaa' - -# Selected text. -Token.SelectedText = '#ffffff bg:#6666aa' - -# Search matches. (reverse-i-search) -Token.SearchMatch = '#ffffff bg:#4444aa' -Token.SearchMatch.Current = '#ffffff bg:#44aa44' - -# The bottom toolbar. -Token.Toolbar = 'bg:#222222 #aaaaaa' -Token.Toolbar.Off = 'bg:#222222 #888888' -Token.Toolbar.On = 'bg:#222222 #ffffff' - -# Search/arg/system toolbars. -Token.Toolbar.Search = 'noinherit bold' -Token.Toolbar.Search.Text = 'nobold' -Token.Toolbar.System = 'noinherit bold' -Token.Toolbar.Arg = 'noinherit bold' -Token.Toolbar.Arg.Text = 'nobold' +completion-menu.completion.current = 'bg:#ffffff #000000' +completion-menu.completion = 'bg:#008888 #ffffff' +completion-menu.meta.completion.current = 'bg:#44aaaa #000000' +completion-menu.meta.completion = 'bg:#448888 #ffffff' +completion-menu.multi-column-meta = 'bg:#aaffff #000000' +scrollbar.arrow = 'bg:#003333' +scrollbar = 'bg:#00aaaa' +selected = '#ffffff bg:#6666aa' +search = '#ffffff bg:#4444aa' +search.current = '#ffffff bg:#44aa44' +bottom-toolbar = 'bg:#222222 #aaaaaa' +bottom-toolbar.off = 'bg:#222222 #888888' +bottom-toolbar.on = 'bg:#222222 #ffffff' +search-toolbar = 'noinherit bold' +search-toolbar.text = 'nobold' +system-toolbar = 'noinherit bold' +arg-toolbar = 'noinherit bold' +arg-toolbar.text = 'nobold' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' + +# style classes for colored table output +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" +output.null = "#808080" + +# SQL syntax highlighting overrides +# sql.comment = 'italic #408080' +# sql.comment.multi-line = '' +# sql.comment.single-line = '' 
+# sql.comment.optimizer-hint = '' +# sql.escape = 'border:#FF0000' +# sql.keyword = 'bold #008000' +# sql.datatype = 'nobold #B00040' +# sql.literal = '' +# sql.literal.date = '' +# sql.symbol = '' +# sql.quoted-schema-object = '' +# sql.quoted-schema-object.escape = '' +# sql.constant = '#880000' +# sql.function = '#0000FF' +# sql.variable = '#19177C' +# sql.number = '#666666' +# sql.number.binary = '' +# sql.number.float = '' +# sql.number.hex = '' +# sql.number.integer = '' +# sql.operator = '#666666' +# sql.punctuation = '' +# sql.string = '#BA2121' +# sql.string.double-quouted = '' +# sql.string.escape = 'bold #BB6622' +# sql.string.single-quoted = '' +# sql.whitespace = '' # Favorite queries. [favorite_queries] + +# Use the -d option to reference a DSN. +# Special characters in passwords and other strings can be escaped with URL encoding. +[alias_dsn] +# example_dsn = mysql://[user[:password]@][host][:port][/dbname] diff --git a/mycli/output_formatter/delimited_output_adapter.py b/mycli/output_formatter/delimited_output_adapter.py deleted file mode 100644 index a01a2843..00000000 --- a/mycli/output_formatter/delimited_output_adapter.py +++ /dev/null @@ -1,28 +0,0 @@ -import contextlib -import csv -try: - from cStringIO import StringIO -except ImportError: - from io import StringIO - -from .preprocessors import override_missing_value, bytes_to_string - -supported_formats = ('csv', 'tsv') -preprocessors = (override_missing_value, bytes_to_string) - - -def adapter(data, headers, table_format='csv', **_): - """Wrap CSV formatting inside a standard function for OutputFormatter.""" - with contextlib.closing(StringIO()) as content: - if table_format == 'csv': - writer = csv.writer(content, delimiter=',') - elif table_format == 'tsv': - writer = csv.writer(content, delimiter='\t') - else: - raise ValueError('Invalid table_format specified.') - - writer.writerow(headers) - for row in data: - writer.writerow(row) - - return content.getvalue() diff --git 
a/mycli/output_formatter/expanded.py b/mycli/output_formatter/expanded.py deleted file mode 100644 index f77c1ee3..00000000 --- a/mycli/output_formatter/expanded.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Format data into a vertical, expanded table layout.""" - -from __future__ import unicode_literals - - -def get_separator(num): - """Get a row separator for row *num*.""" - return "{divider}[ {n}. row ]{divider}\n".format( - divider='*' * 27, n=num + 1) - - -def format_row(headers, row): - """Format a row.""" - formatted_row = [' | '.join(field) for field in zip(headers, row)] - return '\n'.join(formatted_row) - - -def expanded_table(rows, headers, **_): - """Format *rows* and *headers* as an expanded table. - - The values in *rows* and *headers* must be strings. - - """ - header_len = max([len(x) for x in headers]) - padded_headers = [x.ljust(header_len) for x in headers] - formatted_rows = [format_row(padded_headers, row) for row in rows] - - output = [] - for i, result in enumerate(formatted_rows): - output.append(get_separator(i)) - output.append(result) - output.append('\n') - - return ''.join(output) diff --git a/mycli/output_formatter/output_formatter.py b/mycli/output_formatter/output_formatter.py deleted file mode 100644 index 61e3c8d5..00000000 --- a/mycli/output_formatter/output_formatter.py +++ /dev/null @@ -1,95 +0,0 @@ -# -*- coding: utf-8 -*- -"""A generic output formatter interface.""" - -from __future__ import unicode_literals -from collections import namedtuple - -from .expanded import expanded_table -from .preprocessors import (override_missing_value, convert_to_string) - -from . import delimited_output_adapter -from . import tabulate_adapter -from . 
import terminaltables_adapter - -MISSING_VALUE = '' - -OutputFormatHandler = namedtuple( - 'OutputFormatHandler', - 'format_name preprocessors formatter formatter_args') - - -class OutputFormatter(object): - """A class with a standard interface for various formatting libraries.""" - - _output_formats = {} - - def __init__(self, format_name=None): - """Set the default *format_name*.""" - self._format_name = format_name - - def set_format_name(self, format_name): - """Set the OutputFormatter's default format.""" - if format_name in self.supported_formats(): - self._format_name = format_name - else: - raise ValueError('unrecognized format_name: {}'.format( - format_name)) - - def get_format_name(self): - """Get the OutputFormatter's default format.""" - return self._format_name - - def supported_formats(self): - """Return the supported output format names.""" - return tuple(self._output_formats.keys()) - - @classmethod - def register_new_formatter(cls, format_name, handler, preprocessors=(), - kwargs={}): - """Register a new formatter to format the output.""" - cls._output_formats[format_name] = OutputFormatHandler( - format_name, preprocessors, handler, kwargs) - - def format_output(self, data, headers, format_name=None, **kwargs): - """Format the headers and data using a specific formatter. - - *format_name* must be a formatter available in `supported_formats()`. - - All keyword arguments are passed to the specified formatter. 
- - """ - format_name = format_name or self._format_name - if format_name not in self.supported_formats(): - raise ValueError('unrecognized format: {}'.format(format_name)) - - (_, preprocessors, formatter, - fkwargs) = self._output_formats[format_name] - fkwargs.update(kwargs) - if preprocessors: - for f in preprocessors: - data, headers = f(data, headers, **fkwargs) - return formatter(data, headers, **fkwargs) - - -OutputFormatter.register_new_formatter('expanded', expanded_table, - (override_missing_value, - convert_to_string), - {'missing_value': MISSING_VALUE}) - -for delimiter_format in delimited_output_adapter.supported_formats: - OutputFormatter.register_new_formatter( - delimiter_format, delimited_output_adapter.adapter, - delimited_output_adapter.preprocessors, - {'table_format': delimiter_format, 'missing_value': MISSING_VALUE}) - -for tabulate_format in tabulate_adapter.supported_formats: - OutputFormatter.register_new_formatter( - tabulate_format, tabulate_adapter.adapter, - tabulate_adapter.preprocessors, - {'table_format': tabulate_format, 'missing_value': MISSING_VALUE}) - -for terminaltables_format in terminaltables_adapter.supported_formats: - OutputFormatter.register_new_formatter( - terminaltables_format, terminaltables_adapter.adapter, - terminaltables_adapter.preprocessors, - {'table_format': terminaltables_format, 'missing_value': MISSING_VALUE}) diff --git a/mycli/output_formatter/preprocessors.py b/mycli/output_formatter/preprocessors.py deleted file mode 100644 index 6f2e459c..00000000 --- a/mycli/output_formatter/preprocessors.py +++ /dev/null @@ -1,87 +0,0 @@ -from decimal import Decimal - -from mycli import encodingutils - - -def to_string(value): - """Convert *value* to a string.""" - if isinstance(value, encodingutils.binary_type): - return encodingutils.bytes_to_string(value) - else: - return encodingutils.text_type(value) - - -def convert_to_string(data, headers, **_): - """Convert all *data* and *headers* to strings.""" - return 
([[to_string(v) for v in row] for row in data], - [to_string(h) for h in headers]) - - -def override_missing_value(data, headers, missing_value='', **_): - """Override missing values in the data with *missing_value*.""" - return ([[missing_value if v is None else v for v in row] for row in data], - headers) - - -def bytes_to_string(data, headers, **_): - """Convert all *data* and *headers* bytes to strings.""" - return ([[encodingutils.bytes_to_string(v) for v in row] for row in data], - [encodingutils.bytes_to_string(h) for h in headers]) - - -def intlen(value): - """Find (character) length. - - >>> intlen('11.1') - 2 - >>> intlen('11') - 2 - >>> intlen('1.1') - 1 - - """ - pos = value.find('.') - if pos < 0: - pos = len(value) - return pos - - -def align_decimals(data, headers, **_): - """Align decimals to decimal point.""" - pointpos = len(headers) * [0] - for row in data: - for i, v in enumerate(row): - if isinstance(v, Decimal): - v = encodingutils.text_type(v) - pointpos[i] = max(intlen(v), pointpos[i]) - results = [] - for row in data: - result = [] - for i, v in enumerate(row): - if isinstance(v, Decimal): - v = encodingutils.text_type(v) - result.append((pointpos[i] - intlen(v)) * " " + v) - else: - result.append(v) - results.append(result) - return results, headers - - -def quote_whitespaces(data, headers, quotestyle="'", **_): - """Quote leading/trailing whitespace.""" - quote = len(headers) * [False] - for row in data: - for i, v in enumerate(row): - v = encodingutils.text_type(v) - if v.startswith(' ') or v.endswith(' '): - quote[i] = True - - results = [] - for row in data: - result = [] - for i, v in enumerate(row): - quotation = quotestyle if quote[i] else '' - result.append('{quotestyle}{value}{quotestyle}'.format( - quotestyle=quotation, value=v)) - results.append(result) - return results, headers diff --git a/mycli/output_formatter/tabulate_adapter.py b/mycli/output_formatter/tabulate_adapter.py deleted file mode 100644 index b89dcc0b..00000000 
--- a/mycli/output_formatter/tabulate_adapter.py +++ /dev/null @@ -1,22 +0,0 @@ -from mycli.packages import tabulate -from .preprocessors import bytes_to_string, align_decimals - -tabulate.PRESERVE_WHITESPACE = True - -supported_markup_formats = ('mediawiki', 'html', 'latex', 'latex_booktabs', - 'textile', 'moinmoin', 'jira') -supported_table_formats = ('plain', 'simple', 'grid', 'fancy_grid', 'pipe', - 'orgtbl', 'psql', 'rst') -supported_formats = supported_markup_formats + supported_table_formats - -preprocessors = (bytes_to_string, align_decimals) - - -def adapter(data, headers, table_format=None, missing_value='', **_): - """Wrap tabulate inside a standard function for OutputFormatter.""" - kwargs = {'tablefmt': table_format, 'missingval': missing_value, - 'disable_numparse': True} - if table_format in supported_markup_formats: - kwargs.update(numalign=None, stralign=None) - - return tabulate.tabulate(data, headers, **kwargs) diff --git a/mycli/output_formatter/terminaltables_adapter.py b/mycli/output_formatter/terminaltables_adapter.py deleted file mode 100644 index a8f50f98..00000000 --- a/mycli/output_formatter/terminaltables_adapter.py +++ /dev/null @@ -1,25 +0,0 @@ -import terminaltables - -from .preprocessors import (bytes_to_string, align_decimals, - override_missing_value) - -supported_formats = ('ascii', 'double', 'github') -preprocessors = (bytes_to_string, override_missing_value, align_decimals) - - -def adapter(data, headers, table_format=None, **_): - """Wrap terminaltables inside a standard function for OutputFormatter.""" - - table_format_handler = { - 'ascii': terminaltables.AsciiTable, - 'double': terminaltables.DoubleTable, - 'github': terminaltables.GithubFlavoredMarkdownTable, - } - - try: - table = table_format_handler[table_format] - except KeyError: - raise ValueError('unrecognized table format: {}'.format(table_format)) - - t = table([headers] + data) - return t.table diff --git a/mycli/packages/completion_engine.py 
b/mycli/packages/completion_engine.py index b97cadf7..c7db06cb 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -1,19 +1,8 @@ -from __future__ import print_function -import sys import sqlparse from sqlparse.sql import Comparison, Identifier, Where -from sqlparse.compat import text_type from .parseutils import last_word, extract_tables, find_prev_keyword from .special import parse_special_command -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 - -if PY3: - string_types = str -else: - string_types = basestring - def suggest_type(full_text, text_before_cursor): """Takes the full_text that is typed so far and also the text before the @@ -28,28 +17,32 @@ def suggest_type(full_text, text_before_cursor): identifier = None - # If we've partially typed a word then word_before_cursor won't be an empty - # string. In that case we want to remove the partially typed string before - # sending it to the sqlparser. Otherwise the last token will always be the - # partially typed string which renders the smart completion useless because - # it will always return the list of keywords as completion. - if word_before_cursor: - if word_before_cursor.endswith( - '(') or word_before_cursor.startswith('\\'): - parsed = sqlparse.parse(text_before_cursor) - else: - parsed = sqlparse.parse( - text_before_cursor[:-len(word_before_cursor)]) + # here should be removed once sqlparse has been fixed + try: + # If we've partially typed a word then word_before_cursor won't be an empty + # string. In that case we want to remove the partially typed string before + # sending it to the sqlparser. Otherwise the last token will always be the + # partially typed string which renders the smart completion useless because + # it will always return the list of keywords as completion. 
+ if word_before_cursor: + if word_before_cursor.endswith( + '(') or word_before_cursor.startswith('\\'): + parsed = sqlparse.parse(text_before_cursor) + else: + parsed = sqlparse.parse( + text_before_cursor[:-len(word_before_cursor)]) - # word_before_cursor may include a schema qualification, like - # "schema_name.partial_name" or "schema_name.", so parse it - # separately - p = sqlparse.parse(word_before_cursor)[0] + # word_before_cursor may include a schema qualification, like + # "schema_name.partial_name" or "schema_name.", so parse it + # separately + p = sqlparse.parse(word_before_cursor)[0] - if p.tokens and isinstance(p.tokens[0], Identifier): - identifier = p.tokens[0] - else: - parsed = sqlparse.parse(text_before_cursor) + if p.tokens and isinstance(p.tokens[0], Identifier): + identifier = p.tokens[0] + else: + parsed = sqlparse.parse(text_before_cursor) + except (TypeError, AttributeError): + return [{'type': 'keyword'}] if len(parsed) > 1: # Multiple statements being edited -- isolate the current one by @@ -59,7 +52,7 @@ def suggest_type(full_text, text_before_cursor): stmt_start, stmt_end = 0, 0 for statement in parsed: - stmt_len = len(text_type(statement)) + stmt_len = len(str(statement)) stmt_start, stmt_end = stmt_end, stmt_end + stmt_len if stmt_end >= current_pos: @@ -79,7 +72,7 @@ def suggest_type(full_text, text_before_cursor): # Be careful here because trivial whitespace is parsed as a statement, # but the statement won't have a first token tok1 = statement.token_first() - if tok1 and tok1.value == '\\': + if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')): return suggest_special(text_before_cursor) last_token = statement and statement.token_prev(len(statement.tokens))[1] or '' @@ -90,7 +83,7 @@ def suggest_type(full_text, text_before_cursor): def suggest_special(text): text = text.lstrip() - cmd, arg = parse_special_command(text) + cmd, _, arg = parse_special_command(text) if cmd == text: # Trying to complete the special 
command itself @@ -105,17 +98,20 @@ def suggest_special(text): if cmd in ['\\f', '\\fs', '\\fd']: return [{'type': 'favoritequery'}] - if cmd in ['\\dt']: + if cmd in ['\\dt', '\\dt+']: return [ {'type': 'table', 'schema': []}, {'type': 'view', 'schema': []}, {'type': 'schema'}, ] + elif cmd in ['\\.', 'source']: + return[{'type': 'file_name'}] return [{'type': 'keyword'}, {'type': 'special'}] + def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): - if isinstance(token, string_types): + if isinstance(token, str): token_v = token.lower() elif isinstance(token, Comparison): # If 'token' is a Comparison type such as @@ -191,7 +187,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier # We're probably in a function argument list return [{'type': 'column', 'tables': extract_tables(full_text)}] - elif token_v in ('set', 'by', 'distinct'): + elif token_v in ('set', 'order by', 'distinct'): return [{'type': 'column', 'tables': extract_tables(full_text)}] elif token_v == 'as': # Don't suggest anything for an alias @@ -210,16 +206,18 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier # Check for a table alias or schema qualification parent = (identifier and identifier.get_parent_name()) or [] + tables = extract_tables(full_text) if parent: - tables = extract_tables(full_text) tables = [t for t in tables if identifies(parent, *t)] return [{'type': 'column', 'tables': tables}, {'type': 'table', 'schema': parent}, {'type': 'view', 'schema': parent}, {'type': 'function', 'schema': parent}] else: - return [{'type': 'column', 'tables': extract_tables(full_text)}, + aliases = [alias or table for (schema, table, alias) in tables] + return [{'type': 'column', 'tables': tables}, {'type': 'function', 'schema': []}, + {'type': 'alias', 'aliases': aliases}, {'type': 'keyword'}] elif (token_v.endswith('join') and token.is_keyword) or (token_v in ('copy', 'from', 'update', 'into', 'describe', 
'truncate', @@ -262,7 +260,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier else: # ON # Use table alias if there is one, otherwise the table name - aliases = [t[2] or t[1] for t in tables] + aliases = [alias or table for (schema, table, alias) in tables] suggest = [{'type': 'alias', 'aliases': aliases}] # The lists of 'aliases' could be empty if we're trying to complete diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py new file mode 100644 index 00000000..79fe26dc --- /dev/null +++ b/mycli/packages/filepaths.py @@ -0,0 +1,106 @@ +import os +import platform + + +if os.name == "posix": + if platform.system() == "Darwin": + DEFAULT_SOCKET_DIRS = ("/tmp",) + else: + DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib") +else: + DEFAULT_SOCKET_DIRS = () + + +def list_path(root_dir): + """List directory if exists. + + :param root_dir: str + :return: list + + """ + res = [] + if os.path.isdir(root_dir): + for name in os.listdir(root_dir): + res.append(name) + return res + + +def complete_path(curr_dir, last_dir): + """Return the path to complete that matches the last entered component. + + If the last entered component is ~, expanded path would not + match, so return all of the available paths. + + :param curr_dir: str + :param last_dir: str + :return: str + + """ + if not last_dir or curr_dir.startswith(last_dir): + return curr_dir + elif last_dir == '~': + return os.path.join(last_dir, curr_dir) + + +def parse_path(root_dir): + """Split path into head and last component for the completer. + + Also return position where last component starts. + + :param root_dir: str path + :return: tuple of (string, string, int) + + """ + base_dir, last_dir, position = '', '', 0 + if root_dir: + base_dir, last_dir = os.path.split(root_dir) + position = -len(last_dir) if last_dir else 0 + return base_dir, last_dir, position + + +def suggest_path(root_dir): + """List all files and subdirectories in a directory. 
+ + If the directory is not specified, suggest root directory, + user directory, current and parent directory. + + :param root_dir: string: directory to list + :return: list + + """ + if not root_dir: + return [os.path.abspath(os.sep), '~', os.curdir, os.pardir] + + if '~' in root_dir: + root_dir = os.path.expanduser(root_dir) + + if not os.path.exists(root_dir): + root_dir, _ = os.path.split(root_dir) + + return list_path(root_dir) + + +def dir_path_exists(path): + """Check if the directory path exists for a given file. + + For example, for a file /home/user/.cache/mycli/log, check if + /home/user/.cache/mycli exists. + + :param str path: The file path. + :return: Whether or not the directory path exists. + + """ + return os.path.exists(os.path.dirname(path)) + + +def guess_socket_location(): + """Try to guess the location of the default mysql socket file.""" + socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS) + for directory in socket_dirs: + for r, dirs, files in os.walk(directory, topdown=True): + for filename in files: + name, ext = os.path.splitext(filename) + if name.startswith("mysql") and ext in ('.socket', '.sock'): + return os.path.join(r, filename) + dirs[:] = [d for d in dirs if d.startswith("mysql")] + return None diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py new file mode 100644 index 00000000..045b00ea --- /dev/null +++ b/mycli/packages/paramiko_stub/__init__.py @@ -0,0 +1,28 @@ +"""A module to import instead of paramiko when it is not available (to avoid +checking for paramiko all over the place). + +When paramiko is first envoked, it simply shuts down mycli, telling +user they either have to install paramiko or should not use SSH +features. 
+ +""" + + +class Paramiko: + def __getattr__(self, name): + import sys + from textwrap import dedent + print(dedent(""" + To enable certain SSH features you need to install paramiko: + + pip install paramiko + + It is required for the following configuration options: + --list-ssh-config + --ssh-config-host + --ssh-host + """)) + sys.exit(1) + + +paramiko = Paramiko() diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index 7f848ad8..d47f59a5 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -1,4 +1,3 @@ -from __future__ import print_function import re import sqlparse from sqlparse.sql import IdentifierList, Identifier, Function @@ -12,11 +11,12 @@ # This matches everything except spaces, parens, colon, comma, and period 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), # This matches everything except a space. - 'all_punctuations': re.compile('([^\s]+)$'), - } + 'all_punctuations': re.compile(r'([^\s]+)$'), +} + def last_word(text, include='alphanum_underscore'): - """ + r""" Find the last word in a sentence. >>> last_word('abc') @@ -80,7 +80,14 @@ def extract_from_part(parsed, stop_at_punctuation=True): for x in extract_from_part(item, stop_at_punctuation): yield x elif stop_at_punctuation and item.ttype is Punctuation: - raise StopIteration + return + # Multiple JOINs in the same query won't work properly since + # "ON" is a keyword and will trigger the next elif condition. + # So instead of stooping the loop when finding an "ON" skip it + # eg: 'SELECT * FROM abc JOIN def ON abc.id = def.abc_id JOIN ghi' + elif item.ttype is Keyword and item.value.upper() == 'ON': + tbl_prefix_seen = False + continue # An incomplete nested select won't be recognized correctly as a # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. 
This causes # the second FROM to trigger this elif condition resulting in a @@ -92,7 +99,7 @@ def extract_from_part(parsed, stop_at_punctuation=True): elif item.ttype is Keyword and ( not item.value.upper() == 'FROM') and ( not item.value.upper().endswith('JOIN')): - raise StopIteration + return else: yield item elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and @@ -187,6 +194,73 @@ def find_prev_keyword(sql): return None, '' + +def query_starts_with(query, prefixes): + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + formatted_sql = sqlparse.format(query.lower(), strip_comments=True) + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def queries_start_with(queries, prefixes): + """Check if any queries start with any item from *prefixes*.""" + for query in sqlparse.split(queries): + if query and query_starts_with(query, prefixes) is True: + return True + return False + + +def query_has_where_clause(query): + """Check if the query contains a where-clause.""" + return any( + isinstance(token, sqlparse.sql.Where) + for token_list in sqlparse.parse(query) + for token in token_list + ) + + +def is_destructive(queries): + """Returns if any of the queries in *queries* is destructive.""" + keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter') + for query in sqlparse.split(queries): + if query: + if query_starts_with(query, keywords) is True: + return True + elif query_starts_with( + query, ['update'] + ) is True and not query_has_where_clause(query): + return True + + return False + + if __name__ == '__main__': sql = 'select * from (select t. 
from tabl t' print (extract_tables(sql)) + + +def is_dropping_database(queries, dbname): + """Determine if the query is dropping a specific database.""" + result = False + if dbname is None: + return False + + def normalize_db_name(db): + return db.lower().strip('`"') + + dbname = normalize_db_name(dbname) + + for query in sqlparse.parse(queries): + keywords = [t for t in query.tokens if t.is_keyword] + if len(keywords) < 2: + continue + if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in ( + "database", + "schema", + ): + database_token = next( + (t for t in query.tokens if isinstance(t, Identifier)), None + ) + if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: + result = keywords[0].normalized == "DROP" + return result diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py new file mode 100644 index 00000000..fb1e431a --- /dev/null +++ b/mycli/packages/prompt_utils.py @@ -0,0 +1,54 @@ +import sys +import click +from .parseutils import is_destructive + + +class ConfirmBoolParamType(click.ParamType): + name = 'confirmation' + + def convert(self, value, param, ctx): + if isinstance(value, bool): + return bool(value) + value = value.lower() + if value in ('yes', 'y'): + return True + elif value in ('no', 'n'): + return False + self.fail('%s is not a valid boolean' % value, param, ctx) + + def __repr__(self): + return 'BOOL' + + +BOOLEAN_TYPE = ConfirmBoolParamType() + + +def confirm_destructive_query(queries): + """Check if the query is destructive and prompts the user to confirm. + + Returns: + * None if the query is non-destructive or we can't prompt the user. + * True if the query is destructive and the user wants to proceed. + * False if the query is destructive and the user doesn't want to proceed. + + """ + prompt_text = ("You're about to run a destructive command.\n" + "Do you want to proceed? 
(y/n)") + if is_destructive(queries) and sys.stdin.isatty(): + return prompt(prompt_text, type=BOOLEAN_TYPE) + + +def confirm(*args, **kwargs): + """Prompt for confirmation (yes/no) and handle any abort exceptions.""" + try: + return click.confirm(*args, **kwargs) + except click.Abort: + return False + + +def prompt(*args, **kwargs): + """Prompt the user for input and handle any abort exceptions.""" + try: + return click.prompt(*args, **kwargs) + except click.Abort: + return False diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py index 997d3d0a..45d70690 100644 --- a/mycli/packages/special/dbcommands.py +++ b/mycli/packages/special/dbcommands.py @@ -5,23 +5,35 @@ from mycli.packages.special import iocommands from mycli.packages.special.utils import format_uptime from .main import special_command, RAW_QUERY, PARSED_QUERY +from pymysql import ProgrammingError log = logging.getLogger(__name__) -@special_command('\\dt', '\\dt [table]', 'List or describe tables.', arg_type=PARSED_QUERY, case_sensitive=True) -def list_tables(cur, arg=None, arg_type=PARSED_QUERY): + +@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.', + arg_type=PARSED_QUERY, case_sensitive=True) +def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): if arg: query = 'SHOW FIELDS FROM {0}'.format(arg) else: query = 'SHOW TABLES' log.debug(query) cur.execute(query) + tables = cur.fetchall() + status = '' if cur.description: headers = [x[0] for x in cur.description] - return [(None, cur, headers, '')] else: return [(None, None, None, '')] + if verbose and arg: + query = 'SHOW CREATE TABLE {0}'.format(arg) + log.debug(query) + cur.execute(query) + status = cur.fetchone()[1] + + return [(None, tables, headers, status)] + @special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True) def list_databases(cur, **_): query = 'SHOW DATABASES' @@ -38,7 +50,13 @@ def list_databases(cur, **_): def status(cur, **_): 
query = 'SHOW GLOBAL STATUS;' log.debug(query) - cur.execute(query) + try: + cur.execute(query) + except ProgrammingError: + # Fallback in case query fail, as it does with Mysql 4 + query = 'SHOW STATUS;' + log.debug(query) + cur.execute(query) status = dict(cur.fetchall()) query = 'SHOW GLOBAL VARIABLES;' @@ -46,6 +64,14 @@ def status(cur, **_): cur.execute(query) variables = dict(cur.fetchall()) + # prepare in case keys are bytes, as with Python 3 and Mysql 4 + if (isinstance(list(variables)[0], bytes) and + isinstance(list(status)[0], bytes)): + variables = {k.decode('utf-8'): v.decode('utf-8') for k, v + in variables.items()} + status = {k.decode('utf-8'): v.decode('utf-8') for k, v + in status.items()} + # Create output buffers. title = [] output = [] @@ -109,20 +135,25 @@ def status(cur, **_): else: output.append(('UNIX socket:', variables['socket'])) - output.append(('Uptime:', format_uptime(status['Uptime']))) - - # Print the current server statistics. - stats = [] - stats.append('Connections: {0}'.format(status['Threads_connected'])) - stats.append('Queries: {0}'.format(status['Queries'])) - stats.append('Slow queries: {0}'.format(status['Slow_queries'])) - stats.append('Opens: {0}'.format(status['Opened_tables'])) - stats.append('Flush tables: {0}'.format(status['Flush_commands'])) - stats.append('Open tables: {0}'.format(status['Open_tables'])) - queries_per_second = int(status['Queries']) / int(status['Uptime']) - stats.append('Queries per second avg: {:.3f}'.format(queries_per_second)) - stats = ' '.join(stats) - footer.append('\n' + stats) + if 'Uptime' in status: + output.append(('Uptime:', format_uptime(status['Uptime']))) + + if 'Threads_connected' in status: + # Print the current server statistics. 
+ stats = [] + stats.append('Connections: {0}'.format(status['Threads_connected'])) + if 'Queries' in status: + stats.append('Queries: {0}'.format(status['Queries'])) + stats.append('Slow queries: {0}'.format(status['Slow_queries'])) + stats.append('Opens: {0}'.format(status['Opened_tables'])) + stats.append('Flush tables: {0}'.format(status['Flush_commands'])) + stats.append('Open tables: {0}'.format(status['Open_tables'])) + if 'Queries' in status: + queries_per_second = int(status['Queries']) / int(status['Uptime']) + stats.append('Queries per second avg: {:.3f}'.format( + queries_per_second)) + stats = ' '.join(stats) + footer.append('\n' + stats) footer.append('--------------') return [('\n'.join(title), output, '', '\n'.join(footer))] diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py new file mode 100644 index 00000000..994b134b --- /dev/null +++ b/mycli/packages/special/delimitercommand.py @@ -0,0 +1,80 @@ +import re +import sqlparse + + +class DelimiterCommand(object): + def __init__(self): + self._delimiter = ';' + + def _split(self, sql): + """Temporary workaround until sqlparse.split() learns about custom + delimiters.""" + + placeholder = "\ufffc" # unicode object replacement character + + if self._delimiter == ';': + return sqlparse.split(sql) + + # We must find a string that original sql does not contain. 
+ # Most likely, our placeholder is enough, but if not, keep looking + while placeholder in sql: + placeholder += placeholder[0] + sql = sql.replace(';', placeholder) + sql = sql.replace(self._delimiter, ';') + + split = sqlparse.split(sql) + + return [ + stmt.replace(';', self._delimiter).replace(placeholder, ';') + for stmt in split + ] + + def queries_iter(self, input): + """Iterate over queries in the input string.""" + + queries = self._split(input) + while queries: + for sql in queries: + delimiter = self._delimiter + sql = queries.pop(0) + if sql.endswith(delimiter): + trailing_delimiter = True + sql = sql.strip(delimiter) + else: + trailing_delimiter = False + + yield sql + + # if the delimiter was changed by the last command, + # re-split everything, and if we previously stripped + # the delimiter, append it to the end + if self._delimiter != delimiter: + combined_statement = ' '.join([sql] + queries) + if trailing_delimiter: + combined_statement += delimiter + queries = self._split(combined_statement)[1:] + + def set(self, arg, **_): + """Change delimiter. + + Since `arg` is everything that follows the DELIMITER token + after sqlparse (it may include other statements separated by + the new delimiter), we want to set the delimiter to the first + word of it. 
+ + """ + match = arg and re.search(r'[^\s]+', arg) + if not match: + message = 'Missing required argument, delimiter' + return [(None, None, None, message)] + + delimiter = match.group() + if delimiter.lower() == 'delimiter': + return [(None, None, None, 'Invalid delimiter "delimiter"')] + + self._delimiter = delimiter + return [(None, None, None, "Changed delimiter to {}".format(delimiter))] + + @property + def current(self): + return self._delimiter diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py index bec14ecb..0b91400e 100644 --- a/mycli/packages/special/favoritequeries.py +++ b/mycli/packages/special/favoritequeries.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - class FavoriteQueries(object): section_name = 'favorite_queries' @@ -26,7 +23,7 @@ class FavoriteQueries(object): ╒════════╤════════╕ │ a │ b │ ╞════════╪════════╡ - │ 日本語 │ 日本語 │ + │ 日本語 │ 日本語 │ ╘════════╧════════╛ # Delete a favorite query. @@ -34,9 +31,16 @@ class FavoriteQueries(object): simple: Deleted ''' + # Class-level variable, for convenience to use as a singleton. + instance = None + def __init__(self, config): self.config = config + @classmethod + def from_config(cls, config): + return FavoriteQueries(config) + def list(self): return self.config.get(self.section_name, []) @@ -44,6 +48,7 @@ def get(self, name): return self.config.get(self.section_name, {}).get(name, None) def save(self, name, query): + self.config.encoding = 'utf-8' if self.section_name not in self.config: self.config[self.section_name] = {} self.config[self.section_name][name] = query @@ -56,6 +61,3 @@ def delete(self, name): return '%s: Not Found.' 
% name self.config.write() return '%s: Deleted' % name - -from ...config import read_config_file -favoritequeries = FavoriteQueries(read_config_file('~/.myclirc')) diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 4d4fcc65..01f3c7ba 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -3,20 +3,31 @@ import locale import logging import subprocess +import shlex from io import open +from time import sleep import click +import pyperclip import sqlparse from . import export from .main import special_command, NO_QUERY, PARSED_QUERY -from .favoritequeries import favoritequeries +from .favoritequeries import FavoriteQueries +from .delimitercommand import DelimiterCommand from .utils import handle_cd_command +from mycli.packages.prompt_utils import confirm_destructive_query TIMING_ENABLED = False use_expanded_output = False PAGER_ENABLED = True tee_file = None +once_file = None +written_to_once_file = False +pipe_once_process = None +written_to_pipe_once_process = False +delimiter_command = DelimiterCommand() + @export def set_timing_enabled(val): @@ -28,12 +39,15 @@ def set_pager_enabled(val): global PAGER_ENABLED PAGER_ENABLED = val + @export def is_pager_enabled(): return PAGER_ENABLED @export -@special_command('pager', '\\P [command]', 'Set PAGER. Print the query results via PAGER', arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True) +@special_command('pager', '\\P [command]', + 'Set PAGER. 
Print the query results via PAGER.', + arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True) def set_pager(arg, **_): if arg: os.environ['PAGER'] = arg @@ -77,12 +91,6 @@ def set_expanded_output(val): def is_expanded_output(): return use_expanded_output -def quit(*args): - raise NotImplementedError - -def stub(*args): - raise NotImplementedError - _logger = logging.getLogger(__name__) @export @@ -101,36 +109,45 @@ def get_filename(sql): command, _, filename = sql.partition(' ') return filename.strip() or None -@export -def open_external_editor(filename=None, sql=''): - """ - Open external editor, wait for the user to type in his query, - return the query. - :return: list with one tuple, query as first element. - """ +@export +def get_editor_query(sql): + """Get the query part of an editor command.""" sql = sql.strip() # The reason we can't simply do .strip('\e') is that it strips characters, # not a substring. So it'll strip "e" in the end of the sql also! # Ex: "select * from style\e" -> "select * from styl". - pattern = re.compile('(^\\\e|\\\e$)') + pattern = re.compile(r'(^\\e|\\e$)') while pattern.search(sql): sql = pattern.sub('', sql) + return sql + + +@export +def open_external_editor(filename=None, sql=None): + """Open external editor, wait for the user to type in their query, return + the query. + + :return: list with one tuple, query as first element. + + """ + message = None filename = filename.strip().split(' ', 1)[0] if filename else None + sql = sql or '' MARKER = '# Type your query above this line.\n' # Populate the editor buffer with the partial sql (if available) and a # placeholder comment. 
- query = click.edit(sql + '\n\n' + MARKER, filename=filename, - extension='.sql') + query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER), + filename=filename, extension='.sql') if filename: try: - with open(filename, encoding='utf-8') as f: + with open(filename) as f: query = f.read() except IOError: message = 'Error reading file: %s.' % filename @@ -144,47 +161,114 @@ def open_external_editor(filename=None, sql=''): return (query, message) -@special_command('\\f', '\\f [name]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True) -def execute_favorite_query(cur, arg): + +@export +def clip_command(command): + """Is this a clip command? + + :param command: string + + """ + # It is possible to have `\clip` or `SELECT * FROM \clip`. So we check + # for both conditions. + return command.strip().endswith('\\clip') or command.strip().startswith('\\clip') + + +@export +def get_clip_query(sql): + """Get the query part of a clip command.""" + sql = sql.strip() + + # The reason we can't simply do .strip('\clip') is that it strips characters, + # not a substring. So it'll strip "c" in the end of the sql also! + pattern = re.compile(r'(^\\clip|\\clip$)') + while pattern.search(sql): + sql = pattern.sub('', sql) + + return sql + + +@export +def copy_query_to_clipboard(sql=None): + """Send query to the clipboard.""" + + sql = sql or '' + message = None + + try: + pyperclip.copy(u'{sql}'.format(sql=sql)) + except RuntimeError as e: + message = 'Error clipping query: %s.' 
% e.strerror + + return message + + +@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True) +def execute_favorite_query(cur, arg, **_): """Returns (title, rows, headers, status)""" if arg == '': for result in list_favorite_queries(): yield result - query = favoritequeries.get(arg) + """Parse out favorite name and optional substitution parameters""" + name, _, arg_str = arg.partition(' ') + args = shlex.split(arg_str) + + query = FavoriteQueries.instance.get(name) if query is None: - message = "No favorite query: %s" % (arg) + message = "No favorite query: %s" % (name) yield (None, None, None, message) else: - for sql in sqlparse.split(query): - sql = sql.rstrip(';') - title = '> %s' % (sql) - cur.execute(sql) - if cur.description: - headers = [x[0] for x in cur.description] - yield (title, cur, headers, None) - else: - yield (title, None, None, None) + query, arg_error = subst_favorite_query_args(query, args) + if arg_error: + yield (None, None, None, arg_error) + else: + for sql in sqlparse.split(query): + sql = sql.rstrip(';') + title = '> %s' % (sql) + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) def list_favorite_queries(): """List of all favorite queries. Returns (title, rows, headers, status)""" headers = ["Name", "Query"] - rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()] + rows = [(r, FavoriteQueries.instance.get(r)) + for r in FavoriteQueries.instance.list()] if not rows: - status = '\nNo favorite queries found.' + favoritequeries.usage + status = '\nNo favorite queries found.' 
+ FavoriteQueries.instance.usage else: status = '' return [('', rows, headers, status)] + +def subst_favorite_query_args(query, args): + """replace positional parameters ($1...$N) in query.""" + for idx, val in enumerate(args): + subst_var = '$' + str(idx + 1) + if subst_var not in query: + return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query] + + query = query.replace(subst_var, val) + + match = re.search(r'\$\d+', query) + if match: + return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query] + + return [query, None] + @special_command('\\fs', '\\fs name query', 'Save a favorite query.') def save_favorite_query(arg, **_): """Save a new favorite query. Returns (title, rows, headers, status)""" - usage = 'Syntax: \\fs name query.\n\n' + favoritequeries.usage + usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage if not arg: return [(None, None, None, usage)] @@ -195,30 +279,30 @@ def save_favorite_query(arg, **_): return [(None, None, None, usage + 'Err: Both name and query are required.')] - favoritequeries.save(name, query) + FavoriteQueries.instance.save(name, query) return [(None, None, None, "Saved.")] + @special_command('\\fd', '\\fd [name]', 'Delete a favorite query.') def delete_favorite_query(arg, **_): - """Delete an existing favorite query. - """ - usage = 'Syntax: \\fd name.\n\n' + favoritequeries.usage + """Delete an existing favorite query.""" + usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage if not arg: return [(None, None, None, usage)] - status = favoritequeries.delete(arg) + status = FavoriteQueries.instance.delete(arg) return [(None, None, None, status)] -@special_command('system', 'system [command]', 'Execute a system commmand.') + +@special_command('system', 'system [command]', + 'Execute a system shell commmand.') def execute_system_command(arg, **_): - """ - Execute a system command. 
- """ + """Execute a system shell command.""" usage = "Syntax: system [command].\n" if not arg: - return [(None, None, None, usage)] + return [(None, None, None, usage)] try: command = arg.strip() @@ -242,10 +326,8 @@ def execute_system_command(arg, **_): except OSError as e: return [(None, None, None, 'OSError: %s' % e.strerror)] -@special_command('tee', 'tee [-o] filename', - 'write to an output file (optionally overwrite using -o)') -def set_tee(arg, **_): - global tee_file + +def parseargfile(arg): if arg.startswith('-o '): mode = "w" filename = arg[3:] @@ -256,8 +338,16 @@ def set_tee(arg, **_): if not filename: raise TypeError('You must provide a filename.') + return {'file': os.path.expanduser(filename), 'mode': mode} + + +@special_command('tee', 'tee [-o] filename', + 'Append all results to an output file (overwrite using -o).') +def set_tee(arg, **_): + global tee_file + try: - tee_file = open(filename, mode) + tee_file = open(**parseargfile(arg)) except (IOError, OSError) as e: raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) @@ -270,7 +360,8 @@ def close_tee(): tee_file.close() tee_file = None -@special_command('notee', 'notee', 'stop writing to an output file') + +@special_command('notee', 'notee', 'Stop writing results to an output file.') def no_tee(arg, **_): close_tee() return [(None, None, None, "")] @@ -279,6 +370,174 @@ def no_tee(arg, **_): def write_tee(output): global tee_file if tee_file: - tee_file.write(output) - tee_file.write(u"\n") + click.echo(output, file=tee_file, nl=False) + click.echo(u'\n', file=tee_file, nl=False) tee_file.flush() + + +@special_command('\\once', '\\o [-o] filename', + 'Append next result to an output file (overwrite using -o).', + aliases=('\\o', )) +def set_once(arg, **_): + global once_file, written_to_once_file + + try: + once_file = open(**parseargfile(arg)) + except (IOError, OSError) as e: + raise OSError("Cannot write to file '{}': {}".format( + e.filename, e.strerror)) + 
written_to_once_file = False + + return [(None, None, None, "")] + + +@export +def write_once(output): + global once_file, written_to_once_file + if output and once_file: + click.echo(output, file=once_file, nl=False) + click.echo(u"\n", file=once_file, nl=False) + once_file.flush() + written_to_once_file = True + + +@export +def unset_once_if_written(): + """Unset the once file, if it has been written to.""" + global once_file, written_to_once_file + if written_to_once_file and once_file: + once_file.close() + once_file = None + + +@special_command('\\pipe_once', '\\| command', + 'Send next result to a subprocess.', + aliases=('\\|', )) +def set_pipe_once(arg, **_): + global pipe_once_process, written_to_pipe_once_process + pipe_once_cmd = shlex.split(arg) + if len(pipe_once_cmd) == 0: + raise OSError("pipe_once requires a command") + written_to_pipe_once_process = False + pipe_once_process = subprocess.Popen(pipe_once_cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=1, + encoding='UTF-8', + universal_newlines=True) + return [(None, None, None, "")] + + +@export +def write_pipe_once(output): + global pipe_once_process, written_to_pipe_once_process + if output and pipe_once_process: + try: + click.echo(output, file=pipe_once_process.stdin, nl=False) + click.echo(u"\n", file=pipe_once_process.stdin, nl=False) + except (IOError, OSError) as e: + pipe_once_process.terminate() + raise OSError( + "Failed writing to pipe_once subprocess: {}".format(e.strerror)) + written_to_pipe_once_process = True + + +@export +def unset_pipe_once_if_written(): + """Unset the pipe_once cmd, if it has been written to.""" + global pipe_once_process, written_to_pipe_once_process + if written_to_pipe_once_process: + (stdout_data, stderr_data) = pipe_once_process.communicate() + if len(stdout_data) > 0: + print(stdout_data.rstrip(u"\n")) + if len(stderr_data) > 0: + print(stderr_data.rstrip(u"\n")) + pipe_once_process = None + 
written_to_pipe_once_process = False + + +@special_command( + 'watch', + 'watch [seconds] [-c] query', + 'Executes the query every [seconds] seconds (by default 5).' +) +def watch_query(arg, **kwargs): + usage = """Syntax: watch [seconds] [-c] query. + * seconds: The interval at the query will be repeated, in seconds. + By default 5. + * -c: Clears the screen between every iteration. +""" + if not arg: + yield (None, None, None, usage) + return + seconds = 5 + clear_screen = False + statement = None + while statement is None: + arg = arg.strip() + if not arg: + # Oops, we parsed all the arguments without finding a statement + yield (None, None, None, usage) + return + (current_arg, _, arg) = arg.partition(' ') + try: + seconds = float(current_arg) + continue + except ValueError: + pass + if current_arg == '-c': + clear_screen = True + continue + statement = '{0!s} {1!s}'.format(current_arg, arg) + destructive_prompt = confirm_destructive_query(statement) + if destructive_prompt is False: + click.secho("Wise choice!") + return + elif destructive_prompt is True: + click.secho("Your call!") + cur = kwargs['cur'] + sql_list = [ + (sql.rstrip(';'), "> {0!s}".format(sql)) + for sql in sqlparse.split(statement) + ] + old_pager_enabled = is_pager_enabled() + while True: + if clear_screen: + click.clear() + try: + # Somewhere in the code the pager its activated after every yield, + # so we disable it in every iteration + set_pager_enabled(False) + for (sql, title) in sql_list: + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + sleep(seconds) + except KeyboardInterrupt: + # This prints the Ctrl-C character in its own line, which prevents + # to print a line with the cursor positioned behind the prompt + click.secho("", nl=True) + return + finally: + set_pager_enabled(old_pager_enabled) + + +@export +@special_command('delimiter', None, 'Change SQL delimiter.') 
+def set_delimiter(arg, **_): + return delimiter_command.set(arg) + + +@export +def get_current_delimiter(): + return delimiter_command.current + + +@export +def split_queries(input): + for query in delimiter_command.queries_iter(input): + yield query diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py index ee714f5a..ab04f30d 100644 --- a/mycli/packages/special/main.py +++ b/mycli/packages/special/main.py @@ -22,7 +22,9 @@ class CommandNotFound(Exception): @export def parse_special_command(sql): command, _, arg = sql.partition(' ') - return (command, arg.strip()) + verbose = '+' in command + command = command.strip().replace('+', '') + return (command, verbose, arg.strip()) @export def special_command(command, shortcut, description, arg_type=PARSED_QUERY, @@ -50,7 +52,7 @@ def execute(cur, sql): """Execute a special command and return the results. If the special command is not supported a KeyError will be raised. """ - command, arg = parse_special_command(sql) + command, verbose, arg = parse_special_command(sql) if (command not in COMMANDS) and (command.lower() not in COMMANDS): raise CommandNotFound @@ -70,7 +72,7 @@ def execute(cur, sql): if special_cmd.arg_type == NO_QUERY: return special_cmd.handler() elif special_cmd.arg_type == PARSED_QUERY: - return special_cmd.handler(cur=cur, arg=arg) + return special_cmd.handler(cur=cur, arg=arg, verbose=verbose) elif special_cmd.arg_type == RAW_QUERY: return special_cmd.handler(cur=cur, query=sql) @@ -101,9 +103,18 @@ def show_keyword_help(cur, arg): else: return [(None, None, None, 'No help found for {0}.'.format(keyword))] + @special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', )) @special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY) -@special_command('\\e', '\\e', 'Edit command with editor. 
(uses $EDITOR)', arg_type=NO_QUERY, case_sensitive=True) -@special_command('\\G', '\\G', 'Display results vertically.', arg_type=NO_QUERY, case_sensitive=True) +def quit(*_args): + raise EOFError + + +@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).', + arg_type=NO_QUERY, case_sensitive=True) +@special_command('\\clip', '\\clip', 'Copy query to the system clipboard.', + arg_type=NO_QUERY, case_sensitive=True) +@special_command('\\G', '\\G', 'Display current query results vertically.', + arg_type=NO_QUERY, case_sensitive=True) def stub(): raise NotImplementedError diff --git a/mycli/packages/tabular_output/__init__.py b/mycli/packages/tabular_output/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py new file mode 100644 index 00000000..e6587bd3 --- /dev/null +++ b/mycli/packages/tabular_output/sql_format.py @@ -0,0 +1,62 @@ +"""Format adapter for sql.""" + +from mycli.packages.parseutils import extract_tables + +supported_formats = ('sql-insert', 'sql-update', 'sql-update-1', + 'sql-update-2', ) + +preprocessors = () + + +def escape_for_sql_statement(value): + if isinstance(value, bytes): + return f"X'{value.hex()}'" + else: + return formatter.mycli.sqlexecute.conn.escape(value) + + +def adapter(data, headers, table_format=None, **kwargs): + tables = extract_tables(formatter.query) + if len(tables) > 0: + table = tables[0] + if table[0]: + table_name = "{}.{}".format(*table[:2]) + else: + table_name = table[1] + else: + table_name = "`DUAL`" + if table_format == 'sql-insert': + h = "`, `".join(headers) + yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h) + prefix = " " + for d in data: + values = ", ".join(escape_for_sql_statement(v) + for i, v in enumerate(d)) + yield "{}({})".format(prefix, values) + if prefix == " ": + prefix = ", " + yield ";" + if table_format.startswith('sql-update'): + s = table_format.split('-') + keys 
= 1 + if len(s) > 2: + keys = int(s[-1]) + for d in data: + yield "UPDATE {} SET".format(table_name) + prefix = " " + for i, v in enumerate(d[keys:], keys): + yield "{}`{}` = {}".format(prefix, headers[i], escape_for_sql_statement(v)) + if prefix == " ": + prefix = ", " + f = "`{}` = {}" + where = (f.format(headers[i], escape_for_sql_statement( + d[i])) for i in range(keys)) + yield "WHERE {};".format(" AND ".join(where)) + + +def register_new_formatter(TabularOutputFormatter): + global formatter + formatter = TabularOutputFormatter + for sql_format in supported_formats: + TabularOutputFormatter.register_new_formatter( + sql_format, adapter, preprocessors, {'table_format': sql_format}) diff --git a/mycli/packages/tabulate.py b/mycli/packages/tabulate.py deleted file mode 100644 index 1e67cea7..00000000 --- a/mycli/packages/tabulate.py +++ /dev/null @@ -1,1432 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Pretty-print tabular data.""" - -from __future__ import print_function -from __future__ import unicode_literals -from collections import namedtuple, Iterable -from platform import python_version_tuple -import re - - -if python_version_tuple()[0] < "3": - from itertools import izip_longest - from functools import partial - _none_type = type(None) - _bool_type = bool - _int_type = int - _long_type = long - _float_type = float - _text_type = unicode - _binary_type = str - - def _is_file(f): - return isinstance(f, file) - -else: - from itertools import zip_longest as izip_longest - from functools import reduce, partial - _none_type = type(None) - _bool_type = bool - _int_type = int - _long_type = int - _float_type = float - _text_type = str - _binary_type = bytes - basestring = str - - import io - - def _is_file(f): - return isinstance(f, io.IOBase) - -try: - import wcwidth # optional wide-character (CJK) support -except ImportError: - wcwidth = None - - -__all__ = ["tabulate", "tabulate_formats"] -__version__ = "0.8.0" - - -# minimum extra space in headers -MIN_PADDING = 2 - 
-PRESERVE_WHITESPACE = False - -_DEFAULT_FLOATFMT = "g" -_DEFAULT_MISSINGVAL = "" - - -# if True, enable wide-character (CJK) support -WIDE_CHARS_MODE = wcwidth is not None - - -Line = namedtuple("Line", ["begin", "hline", "sep", "end"]) - - -DataRow = namedtuple("DataRow", ["begin", "sep", "end"]) - - -# A table structure is suppposed to be: -# -# --- lineabove --------- -# headerrow -# --- linebelowheader --- -# datarow -# --- linebewteenrows --- -# ... (more datarows) ... -# --- linebewteenrows --- -# last datarow -# --- linebelow --------- -# -# TableFormat's line* elements can be -# -# - either None, if the element is not used, -# - or a Line tuple, -# - or a function: [col_widths], [col_alignments] -> string. -# -# TableFormat's *row elements can be -# -# - either None, if the element is not used, -# - or a DataRow tuple, -# - or a function: [cell_values], [col_widths], [col_alignments] -> string. -# -# padding (an integer) is the amount of white space around data values. -# -# with_header_hide: -# -# - either None, to display all table elements unconditionally, -# - or a list of elements not to be displayed if the table has column -# headers. 
-# -TableFormat = namedtuple("TableFormat", ["lineabove", "linebelowheader", - "linebetweenrows", "linebelow", - "headerrow", "datarow", - "padding", "with_header_hide"]) - - -def _pipe_segment_with_colons(align, colwidth): - """Return a segment of a horizontal line with optional colons which - indicate column's alignment (as in `pipe` output format).""" - w = colwidth - if align in ["right", "decimal"]: - return ('-' * (w - 1)) + ":" - elif align == "center": - return ":" + ('-' * (w - 2)) + ":" - elif align == "left": - return ":" + ('-' * (w - 1)) - else: - return '-' * w - - -def _pipe_line_with_colons(colwidths, colaligns): - """Return a horizontal line with optional colons to indicate column's - alignment (as in `pipe` output format).""" - segments = [_pipe_segment_with_colons(a, w) for a, w in - zip(colaligns, colwidths)] - return "|" + "|".join(segments) + "|" - - -def _mediawiki_row_with_attrs(separator, cell_values, colwidths, colaligns): - alignment = {"left": '', - "right": 'align="right"| ', - "center": 'align="center"| ', - "decimal": 'align="right"| '} - # hard-coded padding _around_ align attribute and value together - # rather than padding parameter which affects only the value - values_with_attrs = [' ' + alignment.get(a, '') + c + ' ' - for c, a in zip(cell_values, colaligns)] - colsep = separator*2 - return (separator + colsep.join(values_with_attrs)).rstrip() - - -def _textile_row_with_attrs(cell_values, colwidths, colaligns): - cell_values[0] += ' ' - alignment = {"left": "<.", "right": ">.", "center": "=.", "decimal": ">."} - values = (alignment.get(a, '') + v for a, v in zip(colaligns, cell_values)) - return '|' + '|'.join(values) + '|' - - -def _html_begin_table_without_header(colwidths_ignore, colaligns_ignore): - # this table header will be suppressed if there is a header row - return "\n".join(["", ""]) - - -def _html_row_with_attrs(celltag, cell_values, colwidths, colaligns): - alignment = {"left": '', - "right": ' style="text-align: 
right;"', - "center": ' style="text-align: center;"', - "decimal": ' style="text-align: right;"'} - values_with_attrs = ["<{0}{1}>{2}".format( - celltag, alignment.get(a, ''), c) for c, a in - zip(cell_values, colaligns)] - rowhtml = "" + "".join(values_with_attrs).rstrip() + "" - if celltag == "th": # it's a header row, create a new table header - rowhtml = "\n".join(["
", - "", - rowhtml, - "", - ""]) - return rowhtml - - -def _moin_row_with_attrs(celltag, cell_values, colwidths, colaligns, - header=''): - alignment = {"left": '', - "right": '', - "center": '', - "decimal": ''} - values_with_attrs = ["{0}{1} {2} ".format(celltag, - alignment.get(a, ''), - header + c + header) - for c, a in zip(cell_values, colaligns)] - return "".join(values_with_attrs) + "||" - - -def _latex_line_begin_tabular(colwidths, colaligns, booktabs=False): - alignment = {"left": "l", "right": "r", "center": "c", "decimal": "r"} - tabular_columns_fmt = "".join([alignment.get(a, "l") for a in colaligns]) - return "\n".join(["\\begin{tabular}{" + tabular_columns_fmt + "}", - "\\toprule" if booktabs else "\hline"]) - - -LATEX_ESCAPE_RULES = {r"&": r"\&", r"%": r"\%", r"$": r"\$", r"#": r"\#", - r"_": r"\_", r"^": r"\^{}", r"{": r"\{", r"}": r"\}", - r"~": r"\textasciitilde{}", "\\": r"\textbackslash{}", - r"<": r"\ensuremath{<}", r">": r"\ensuremath{>}"} - - -def _latex_row(cell_values, colwidths, colaligns, escrules=LATEX_ESCAPE_RULES): - def escape_char(c): - return escrules.get(c, c) - escaped_values = ["".join(map(escape_char, cell)) for cell in cell_values] - rowfmt = DataRow("", "&", "\\\\") - return _build_simple_row(escaped_values, rowfmt) - - -def _rst_escape_first_column(rows, headers): - def escape_empty(val): - if isinstance(val, (_text_type, _binary_type)) and val.strip() is "": - return ".." 
- else: - return val - new_headers = list(headers) - new_rows = [] - if headers: - new_headers[0] = escape_empty(headers[0]) - for row in rows: - new_row = list(row) - if new_row: - new_row[0] = escape_empty(row[0]) - new_rows.append(new_row) - return new_rows, new_headers - - -_table_formats = {"simple": - TableFormat( - lineabove=Line("", "-", " ", ""), - linebelowheader=Line("", "-", " ", ""), - linebetweenrows=None, - linebelow=Line("", "-", " ", ""), - headerrow=DataRow("", " ", ""), - datarow=DataRow("", " ", ""), - padding=0, - with_header_hide=["lineabove", "linebelow"]), - "plain": - TableFormat( - lineabove=None, linebelowheader=None, - linebetweenrows=None, linebelow=None, - headerrow=DataRow("", " ", ""), - datarow=DataRow("", " ", ""), - padding=0, with_header_hide=None), - "grid": - TableFormat( - lineabove=Line("+", "-", "+", "+"), - linebelowheader=Line("+", "=", "+", "+"), - linebetweenrows=Line("+", "-", "+", "+"), - linebelow=Line("+", "-", "+", "+"), - headerrow=DataRow("|", "|", "|"), - datarow=DataRow("|", "|", "|"), - padding=1, with_header_hide=None), - "fancy_grid": - TableFormat( - lineabove=Line("╒", "═", "╤", "╕"), - linebelowheader=Line("╞", "═", "╪", "╡"), - linebetweenrows=Line("├", "─", "┼", "┤"), - linebelow=Line("╘", "═", "╧", "╛"), - headerrow=DataRow("│", "│", "│"), - datarow=DataRow("│", "│", "│"), - padding=1, with_header_hide=None), - "pipe": - TableFormat( - lineabove=_pipe_line_with_colons, - linebelowheader=_pipe_line_with_colons, - linebetweenrows=None, - linebelow=None, - headerrow=DataRow("|", "|", "|"), - datarow=DataRow("|", "|", "|"), - padding=1, - with_header_hide=["lineabove"]), - "orgtbl": - TableFormat( - lineabove=None, - linebelowheader=Line("|", "-", "+", "|"), - linebetweenrows=None, - linebelow=None, - headerrow=DataRow("|", "|", "|"), - datarow=DataRow("|", "|", "|"), - padding=1, with_header_hide=None), - "jira": - TableFormat( - lineabove=None, - linebelowheader=None, - linebetweenrows=None, - 
linebelow=None, - headerrow=DataRow("||", "||", "||"), - datarow=DataRow("|", "|", "|"), - padding=1, with_header_hide=None), - "psql": - TableFormat( - lineabove=Line("+", "-", "+", "+"), - linebelowheader=Line("|", "-", "+", "|"), - linebetweenrows=None, - linebelow=Line("+", "-", "+", "+"), - headerrow=DataRow("|", "|", "|"), - datarow=DataRow("|", "|", "|"), - padding=1, with_header_hide=None), - "rst": - TableFormat( - lineabove=Line("", "=", " ", ""), - linebelowheader=Line("", "=", " ", ""), - linebetweenrows=None, - linebelow=Line("", "=", " ", ""), - headerrow=DataRow("", " ", ""), - datarow=DataRow("", " ", ""), - padding=0, with_header_hide=None), - "mediawiki": - TableFormat(lineabove=Line( - "{| class=\"wikitable\" style=\"text-align: left;\"", - "", "", "\n|+ \n|-"), - linebelowheader=Line("|-", "", "", ""), - linebetweenrows=Line("|-", "", "", ""), - linebelow=Line("|}", "", "", ""), - headerrow=partial(_mediawiki_row_with_attrs, "!"), - datarow=partial(_mediawiki_row_with_attrs, "|"), - padding=0, with_header_hide=None), - "moinmoin": - TableFormat( - lineabove=None, - linebelowheader=None, - linebetweenrows=None, - linebelow=None, - headerrow=partial(_moin_row_with_attrs, "||", - header="'''"), - datarow=partial(_moin_row_with_attrs, "||"), - padding=1, with_header_hide=None), - "html": - TableFormat( - lineabove=_html_begin_table_without_header, - linebelowheader="", - linebetweenrows=None, - linebelow=Line("\n
", "", "", ""), - headerrow=partial(_html_row_with_attrs, "th"), - datarow=partial(_html_row_with_attrs, "td"), - padding=0, with_header_hide=["lineabove"]), - "latex": - TableFormat( - lineabove=_latex_line_begin_tabular, - linebelowheader=Line("\\hline", "", "", ""), - linebetweenrows=None, - linebelow=Line("\\hline\n\\end{tabular}", "", "", ""), - headerrow=_latex_row, - datarow=_latex_row, - padding=1, with_header_hide=None), - "latex_raw": - TableFormat( - lineabove=_latex_line_begin_tabular, - linebelowheader=Line("\\hline", "", "", ""), - linebetweenrows=None, - linebelow=Line("\\hline\n\\end{tabular}", "", "", ""), - headerrow=partial(_latex_row, escrules={}), - datarow=partial(_latex_row, escrules={}), - padding=1, with_header_hide=None), - "latex_booktabs": - TableFormat( - lineabove=partial(_latex_line_begin_tabular, - booktabs=True), - linebelowheader=Line("\\midrule", "", "", ""), - linebetweenrows=None, - linebelow=Line("\\bottomrule\n\\end{tabular}", "", "", - ""), - headerrow=_latex_row, - datarow=_latex_row, - padding=1, with_header_hide=None), - "textile": - TableFormat( - lineabove=None, linebelowheader=None, - linebetweenrows=None, linebelow=None, - headerrow=DataRow("|_. 
", "|_.", "|"), - datarow=_textile_row_with_attrs, - padding=1, with_header_hide=None)} - - -tabulate_formats = list(sorted(_table_formats.keys())) - - -# ANSI color codes -_invisible_codes = re.compile(r"\x1b\[\d+[;\d]*m|\x1b\[\d*\;\d*\;\d*m") -_invisible_codes_bytes = re.compile(b"\x1b\[\d+[;\d]*m|\x1b\[\d*\;\d*\;\d*m") - - -def _isconvertible(conv, string): - try: - n = conv(string) - return True - except (ValueError, TypeError): - return False - - -def _isnumber(string): - """ - >>> _isnumber("123.45") - True - >>> _isnumber("123") - True - >>> _isnumber("spam") - False - """ - return _isconvertible(float, string) - - -def _isint(string, inttype=int): - """ - >>> _isint("123") - True - >>> _isint("123.45") - False - """ - return type(string) is inttype or\ - (isinstance(string, _binary_type) or isinstance(string, _text_type))\ - and\ - _isconvertible(inttype, string) - - -def _isbool(string): - """ - >>> _isbool(True) - True - >>> _isbool("False") - True - >>> _isbool(1) - False - """ - return type(string) is _bool_type or\ - (isinstance(string, (_binary_type, _text_type)) and - string in ("True", "False")) - - -def _type(string, has_invisible=True, numparse=True): - """The least generic type (type(None), int, float, str, unicode). 
- - >>> _type(None) is type(None) - True - >>> _type("foo") is type("") - True - >>> _type("1") is type(1) - True - >>> _type('\x1b[31m42\x1b[0m') is type(42) - True - >>> _type('\x1b[31m42\x1b[0m') is type(42) - True - - """ - - if has_invisible and \ - (isinstance(string, _text_type) or isinstance(string, _binary_type)): - string = _strip_invisible(string) - - if string is None: - return _none_type - elif hasattr(string, "isoformat"): # datetime.datetime, date, and time - return _text_type - elif _isbool(string): - return _bool_type - elif _isint(string) and numparse: - return int - elif _isint(string, _long_type) and numparse: - return int - elif _isnumber(string) and numparse: - return float - elif isinstance(string, _binary_type): - return _binary_type - else: - return _text_type - - -def _afterpoint(string): - """Symbols after a decimal point, -1 if the string lacks the decimal point. - - >>> _afterpoint("123.45") - 2 - >>> _afterpoint("1001") - -1 - >>> _afterpoint("eggs") - -1 - >>> _afterpoint("123e45") - 2 - - """ - if _isnumber(string): - if _isint(string): - return -1 - else: - pos = string.rfind(".") - pos = string.lower().rfind("e") if pos < 0 else pos - if pos >= 0: - return len(string) - pos - 1 - else: - return -1 # no point - else: - return -1 # not a number - - -def _padleft(width, s): - """Flush right. - - >>> _padleft(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430' - True - - """ - fmt = "{0:>%ds}" % width - return fmt.format(s) - - -def _padright(width, s): - """Flush left. - - >>> _padright(6, '\u044f\u0439\u0446\u0430') == '\u044f\u0439\u0446\u0430 ' - True - - """ - fmt = "{0:<%ds}" % width - return fmt.format(s) - - -def _padboth(width, s): - """Center string. - - >>> _padboth(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430 ' - True - - """ - fmt = "{0:^%ds}" % width - return fmt.format(s) - - -def _strip_invisible(s): - "Remove invisible ANSI color codes." 
- if isinstance(s, _text_type): - return re.sub(_invisible_codes, "", s) - else: # a bytestring - return re.sub(_invisible_codes_bytes, "", s) - - -def _visible_width(s): - """Visible width of a printed string. ANSI color codes are removed. - - >>> _visible_width('\x1b[31mhello\x1b[0m'), _visible_width("world") - (5, 5) - - """ - # optional wide-character support - if wcwidth is not None and WIDE_CHARS_MODE: - len_fn = wcwidth.wcswidth - else: - len_fn = len - if isinstance(s, _text_type) or isinstance(s, _binary_type): - return len_fn(_strip_invisible(s)) - else: - return len_fn(_text_type(s)) - - -def _align_column(strings, alignment, minwidth=0, has_invisible=True): - """[string] -> [padded_string] - - >>> list(map(str,_align_column( - ... ["12.345", "-1234.5", "1.23", "1234.5", "1e+234", "1.0e234"], - ... "decimal"))) - [' 12.345 ', '-1234.5 ', ' 1.23 ', ' 1234.5 ', ' 1e+234 ', ' 1.0e234'] - - >>> list(map(str,_align_column(['123.4', '56.7890'], None))) - ['123.4', '56.7890'] - - """ - if alignment == "right": - if not PRESERVE_WHITESPACE: - strings = [s.strip() for s in strings] - padfn = _padleft - elif alignment == "center": - if not PRESERVE_WHITESPACE: - strings = [s.strip() for s in strings] - padfn = _padboth - elif alignment == "decimal": - if has_invisible: - decimals = [_afterpoint(_strip_invisible(s)) for s in strings] - else: - decimals = [_afterpoint(s) for s in strings] - maxdecimals = max(decimals) - strings = [s + (maxdecimals - decs) * " " - for s, decs in zip(strings, decimals)] - padfn = _padleft - elif not alignment: - return strings - else: - if not PRESERVE_WHITESPACE: - strings = [s.strip() for s in strings] - padfn = _padright - - enable_widechars = wcwidth is not None and WIDE_CHARS_MODE - if has_invisible: - width_fn = _visible_width - elif enable_widechars: # optional wide-character support if available - width_fn = wcwidth.wcswidth - else: - width_fn = len - - s_lens = list(map(len, strings)) - s_widths = list(map(width_fn, strings)) 
- maxwidth = max(max(s_widths), minwidth) - if not enable_widechars and not has_invisible: - padded_strings = [padfn(maxwidth, s) for s in strings] - else: - # enable wide-character width corrections - visible_widths = [maxwidth - (w - l) for w, l in zip(s_widths, s_lens)] - # wcswidth and _visible_width don't count invisible characters; - # padfn doesn't need to apply another correction - padded_strings = [padfn(w, s) for s, w in zip(strings, visible_widths)] - return padded_strings - - -def _more_generic(type1, type2): - types = {_none_type: 0, _bool_type: 1, int: 2, float: 3, _binary_type: 4, - _text_type: 5} - invtypes = {5: _text_type, 4: _binary_type, 3: float, 2: int, - 1: _bool_type, 0: _none_type} - moregeneric = max(types.get(type1, 5), types.get(type2, 5)) - return invtypes[moregeneric] - - -def _column_type(strings, has_invisible=True, numparse=True): - """The least generic type all column values are convertible to. - - >>> _column_type([True, False]) is _bool_type - True - >>> _column_type(["1", "2"]) is _int_type - True - >>> _column_type(["1", "2.3"]) is _float_type - True - >>> _column_type(["1", "2.3", "four"]) is _text_type - True - >>> _column_type(["four", '\u043f\u044f\u0442\u044c']) is _text_type - True - >>> _column_type([None, "brux"]) is _text_type - True - >>> _column_type([1, 2, None]) is _int_type - True - >>> import datetime as dt - >>> _column_type([dt.datetime(1991,2,19), dt.time(17,35)]) is _text_type - True - - """ - types = [_type(s, has_invisible, numparse) for s in strings] - return reduce(_more_generic, types, _bool_type) - - -def _format(val, valtype, floatfmt, missingval="", has_invisible=True): - """Format a value accoding to its type. - - Unicode is supported: - - >>> hrow = ['\u0431\u0443\u043a\u0432\u0430', - ... '\u0446\u0438\u0444\u0440\u0430'] - >>> tbl = [['\u0430\u0437', 2], ['\u0431\u0443\u043a\u0438', 4]] - >>> good_result = ('\\u0431\\u0443\\u043a\\u0432\\u0430 ' - ... 
'\\u0446\\u0438\\u0444\\u0440\\u0430\\n------- ' - ... '-------\\n\\u0430\\u0437 ' - ... '2\\n\\u0431\\u0443\\u043a\\u0438 4') - >>> tabulate(tbl, headers=hrow) == good_result - True - - """ - if val is None: - return missingval - - if valtype in [int, _text_type]: - return "{0}".format(val) - elif valtype is _binary_type: - try: - return _text_type(val, "ascii") - except TypeError: - return _text_type(val) - elif valtype is float: - is_a_colored_number = (has_invisible and - isinstance(val, (_text_type, _binary_type))) - if is_a_colored_number: - raw_val = _strip_invisible(val) - formatted_val = format(float(raw_val), floatfmt) - return val.replace(raw_val, formatted_val) - else: - return format(float(val), floatfmt) - else: - return "{0}".format(val) - - -def _align_header(header, alignment, width, visible_width): - """Pad string header to width chars given known visible_width of the - header.""" - width += len(header) - visible_width - if alignment == "left": - return _padright(width, header) - elif alignment == "center": - return _padboth(width, header) - elif not alignment: - return "{0}".format(header) - else: - return _padleft(width, header) - - -def _prepend_row_index(rows, index): - """Add a left-most index column.""" - if index is None or index is False: - return rows - if len(index) != len(rows): - print('index=', index) - print('rows=', rows) - raise ValueError('index must be as long as the number of data rows') - rows = [[v] + list(row) for v, row in zip(index, rows)] - return rows - - -def _bool(val): - """A wrapper around standard bool() which doesn't throw on NumPy - arrays.""" - try: - return bool(val) - except ValueError: # val is likely to be a numpy array with many elements - return False - - -def _normalize_tabular_data(tabular_data, headers, showindex="default"): - """Transform a supported data type to a list of lists, and a list of - headers. 
- - Supported tabular data types: - - * list-of-lists or another iterable of iterables - - * list of named tuples (usually used with headers="keys") - - * list of dicts (usually used with headers="keys") - - * list of OrderedDicts (usually used with headers="keys") - - * 2D NumPy arrays - - * NumPy record arrays (usually used with headers="keys") - - * dict of iterables (usually used with headers="keys") - - * pandas.DataFrame (usually used with headers="keys") - - The first row can be used as headers if headers="firstrow", - column indices can be used as headers if headers="keys". - - If showindex="default", show row indices of the pandas.DataFrame. - If showindex="always", show row indices for all types of data. - If showindex="never", don't show row indices for all types of data. - If showindex is an iterable, show its values as row indices. - - """ - - try: - bool(headers) - is_headers2bool_broken = False - except ValueError: # numpy.ndarray, pandas.core.index.Index, ... - is_headers2bool_broken = True - headers = list(headers) - - index = None - if hasattr(tabular_data, "keys") and hasattr(tabular_data, "values"): - # dict-like and pandas.DataFrame? 
- if hasattr(tabular_data.values, "__call__"): - # likely a conventional dict - keys = tabular_data.keys() - # columns have to be transposed - rows = list(izip_longest(*tabular_data.values())) - elif hasattr(tabular_data, "index"): - # values is a property, has .index => it's likely a - # pandas.DataFrame (pandas 0.11.0) - keys = list(tabular_data) - if tabular_data.index.name is not None: - if isinstance(tabular_data.index.name, list): - keys[:0] = tabular_data.index.name - else: - keys[:0] = [tabular_data.index.name] - # values matrix doesn't need to be transposed - vals = tabular_data.values - # for DataFrames add an index per default - index = list(tabular_data.index) - rows = [list(row) for row in vals] - else: - raise ValueError( - "tabular data doesn't appear to be a dict or a DataFrame") - - if headers == "keys": - headers = list(map(_text_type, keys)) # headers should be strings - - else: # it's a usual an iterable of iterables, or a NumPy array - rows = list(tabular_data) - - if (headers == "keys" and not rows): - # an empty table (issue #81) - headers = [] - elif (headers == "keys" and - hasattr(tabular_data, "dtype") and - getattr(tabular_data.dtype, "names")): - # numpy record array - headers = tabular_data.dtype.names - elif (headers == "keys" - and len(rows) > 0 - and isinstance(rows[0], tuple) - and hasattr(rows[0], "_fields")): - # namedtuple - headers = list(map(_text_type, rows[0]._fields)) - elif (len(rows) > 0 - and isinstance(rows[0], dict)): - # dict or OrderedDict - uniq_keys = set() # implements hashed lookup - keys = [] # storage for set - if headers == "firstrow": - firstdict = rows[0] if len(rows) > 0 else {} - keys.extend(firstdict.keys()) - uniq_keys.update(keys) - rows = rows[1:] - for row in rows: - for k in row.keys(): - # Save unique items in input order - if k not in uniq_keys: - keys.append(k) - uniq_keys.add(k) - if headers == 'keys': - headers = keys - elif isinstance(headers, dict): - # a dict of headers for a list of dicts - 
headers = [headers.get(k, k) for k in keys] - headers = list(map(_text_type, headers)) - elif headers == "firstrow": - if len(rows) > 0: - headers = [firstdict.get(k, k) for k in keys] - headers = list(map(_text_type, headers)) - else: - headers = [] - elif headers: - raise ValueError( - 'headers for a list of dicts is not a dict or a keyword') - rows = [[row.get(k) for k in keys] for row in rows] - - elif (headers == "keys" - and hasattr(tabular_data, "description") - and hasattr(tabular_data, "fetchone") - and hasattr(tabular_data, "rowcount")): - # Python Database API cursor object (PEP 0249) - # print tabulate(cursor, headers='keys') - headers = [column[0] for column in tabular_data.description] - - elif headers == "keys" and len(rows) > 0: - # keys are column indices - headers = list(map(_text_type, range(len(rows[0])))) - - # take headers from the first row if necessary - if headers == "firstrow" and len(rows) > 0: - if index is not None: - headers = [index[0]] + list(rows[0]) - index = index[1:] - else: - headers = rows[0] - headers = list(map(_text_type, headers)) # headers should be strings - rows = rows[1:] - - headers = list(map(_text_type, headers)) - rows = list(map(list, rows)) - - # add or remove an index column - showindex_is_a_str = type(showindex) in [_text_type, _binary_type] - if showindex == "default" and index is not None: - rows = _prepend_row_index(rows, index) - elif isinstance(showindex, Iterable) and not showindex_is_a_str: - rows = _prepend_row_index(rows, list(showindex)) - elif (showindex == "always" or - (_bool(showindex) and not showindex_is_a_str)): - if index is None: - index = list(range(len(rows))) - rows = _prepend_row_index(rows, index) - elif (showindex == "never" or - (not _bool(showindex) and not showindex_is_a_str)): - pass - - # pad with empty headers for initial columns if necessary - if headers and len(rows) > 0: - nhs = len(headers) - ncols = len(rows[0]) - if nhs < ncols: - headers = [""] * (ncols - nhs) + headers - - 
return rows, headers - - -def tabulate(tabular_data, headers=(), tablefmt="simple", - floatfmt=_DEFAULT_FLOATFMT, numalign="decimal", stralign="left", - missingval=_DEFAULT_MISSINGVAL, showindex="default", - disable_numparse=False): - """Format a fixed width table for pretty printing. - - >>> print(tabulate([[1, 2.34], [-56, "8.999"], ["2", "10001"]])) - --- --------- - 1 2.34 - -56 8.999 - 2 10001 - --- --------- - - The first required argument (`tabular_data`) can be a - list-of-lists (or another iterable of iterables), a list of named - tuples, a dictionary of iterables, an iterable of dictionaries, - a two-dimensional NumPy array, NumPy record array, or a Pandas' - dataframe. - - - Table headers - ------------- - - To print nice column headers, supply the second argument (`headers`): - - - `headers` can be an explicit list of column headers - - if `headers="firstrow"`, then the first row of data is used - - if `headers="keys"`, then dictionary keys or column indices are used - - Otherwise a headerless table is produced. - - If the number of headers is less than the number of columns, they - are supposed to be names of the last columns. This is consistent - with the plain-text format of R and Pandas' dataframes. - - >>> print(tabulate([["sex","age"],["Alice","F",24],["Bob","M",19]], - ... headers="firstrow")) - sex age - ----- ----- ----- - Alice F 24 - Bob M 19 - - By default, pandas.DataFrame data have an additional column called - row index. To add a similar column to all other types of data, - use `showindex="always"` or `showindex=True`. To suppress row indices - for all types of data, pass `showindex="never" or `showindex=False`. - To add a custom row index column, pass `showindex=some_iterable`. - - >>> print(tabulate([["F",24],["M",19]], showindex="always")) - - - -- - 0 F 24 - 1 M 19 - - - -- - - - Column alignment - ---------------- - - `tabulate` tries to detect column types automatically, and aligns - the values properly. 
By default it aligns decimal points of the - numbers (or flushes integer numbers to the right), and flushes - everything else to the left. Possible column alignments - (`numalign`, `stralign`) are: "right", "center", "left", "decimal" - (only for `numalign`), and None (to disable alignment). - - - Table formats - ------------- - - `floatfmt` is a format specification used for columns which - contain numeric data with a decimal point. This can also be - a list or tuple of format strings, one per column. - - `None` values are replaced with a `missingval` string (like - `floatfmt`, this can also be a list of values for different - columns): - - >>> print(tabulate([["spam", 1, None], - ... ["eggs", 42, 3.14], - ... ["other", None, 2.7]], missingval="?")) - ----- -- ---- - spam 1 ? - eggs 42 3.14 - other ? 2.7 - ----- -- ---- - - Various plain-text table formats (`tablefmt`) are supported: - 'plain', 'simple', 'grid', 'pipe', 'orgtbl', 'rst', 'mediawiki', - 'latex', 'latex_raw' and 'latex_booktabs'. Variable `tabulate_formats` - contains the list of currently supported formats. - - "plain" format doesn't use any pseudographics to draw tables, - it separates columns with a double space: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... ["strings", "numbers"], "plain")) - strings numbers - spam 41.9999 - eggs 451 - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... tablefmt="plain")) - spam 41.9999 - eggs 451 - - "simple" format is like Pandoc simple_tables: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... ["strings", "numbers"], "simple")) - strings numbers - --------- --------- - spam 41.9999 - eggs 451 - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... tablefmt="simple")) - ---- -------- - spam 41.9999 - eggs 451 - ---- -------- - - "grid" is similar to tables produced by Emacs table.el package or - Pandoc grid_tables: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... 
["strings", "numbers"], "grid")) - +-----------+-----------+ - | strings | numbers | - +===========+===========+ - | spam | 41.9999 | - +-----------+-----------+ - | eggs | 451 | - +-----------+-----------+ - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... tablefmt="grid")) - +------+----------+ - | spam | 41.9999 | - +------+----------+ - | eggs | 451 | - +------+----------+ - - "fancy_grid" draws a grid using box-drawing characters: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... ["strings", "numbers"], "fancy_grid")) - ╒═══════════╤═══════════╕ - │ strings │ numbers │ - ╞═══════════╪═══════════╡ - │ spam │ 41.9999 │ - ├───────────┼───────────┤ - │ eggs │ 451 │ - ╘═══════════╧═══════════╛ - - "pipe" is like tables in PHP Markdown Extra extension or Pandoc - pipe_tables: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... ["strings", "numbers"], "pipe")) - | strings | numbers | - |:----------|----------:| - | spam | 41.9999 | - | eggs | 451 | - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... tablefmt="pipe")) - |:-----|---------:| - | spam | 41.9999 | - | eggs | 451 | - - "orgtbl" is like tables in Emacs org-mode and orgtbl-mode. They - are slightly different from "pipe" format by not using colons to - define column alignment, and using a "+" sign to indicate line - intersections: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... ["strings", "numbers"], "orgtbl")) - | strings | numbers | - |-----------+-----------| - | spam | 41.9999 | - | eggs | 451 | - - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... tablefmt="orgtbl")) - | spam | 41.9999 | - | eggs | 451 | - - "rst" is like a simple table format from reStructuredText; please - note that reStructuredText accepts also "grid" tables: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... 
["strings", "numbers"], "rst")) - ========= ========= - strings numbers - ========= ========= - spam 41.9999 - eggs 451 - ========= ========= - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], tablefmt="rst")) - ==== ======== - spam 41.9999 - eggs 451 - ==== ======== - - "mediawiki" produces a table markup used in Wikipedia and on other - MediaWiki-based sites: - - >>> print(tabulate([["strings", "numbers"], ["spam", 41.9999], - ... ["eggs", "451.0"]], headers="firstrow", - ... tablefmt="mediawiki")) - {| class="wikitable" style="text-align: left;" - |+ - |- - ! strings !! align="right"| numbers - |- - | spam || align="right"| 41.9999 - |- - | eggs || align="right"| 451 - |} - - "html" produces HTML markup: - - >>> print(tabulate([["strings", "numbers"], ["spam", 41.9999], - ... ["eggs", "451.0"]], headers="firstrow", - ... tablefmt="html")) - - - - - - - - -
strings numbers
spam 41.9999
eggs 451
- - "latex" produces a tabular environment of LaTeX document markup: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... tablefmt="latex")) - \\begin{tabular}{lr} - \\hline - spam & 41.9999 \\\\ - eggs & 451 \\\\ - \\hline - \\end{tabular} - - "latex_raw" is similar to "latex", but doesn't escape special characters, - such as backslash and underscore, so LaTeX commands may embedded into - cells' values: - - >>> print(tabulate([["spam$_9$", 41.9999], ["\\\\emph{eggs}", "451.0"]], - ... tablefmt="latex_raw")) - \\begin{tabular}{lr} - \\hline - spam$_9$ & 41.9999 \\\\ - \\emph{eggs} & 451 \\\\ - \\hline - \\end{tabular} - - "latex_booktabs" produces a tabular environment of LaTeX document markup - using the booktabs.sty package: - - >>> print(tabulate([["spam", 41.9999], ["eggs", "451.0"]], - ... tablefmt="latex_booktabs")) - \\begin{tabular}{lr} - \\toprule - spam & 41.9999 \\\\ - eggs & 451 \\\\ - \\bottomrule - \end{tabular} - - Number parsing - -------------- - By default, anything which can be parsed as a number is a number. - This ensures numbers represented as strings are aligned properly. - This can lead to weird results for particular strings such as - specific git SHAs e.g. "42992e1" will be parsed into the number - 429920 and aligned as such. - - To completely disable number parsing (and alignment), use - `disable_numparse=True`. For more fine grained control, a list column - indices is used to disable number parsing only on those columns - e.g. `disable_numparse=[0, 2]` would disable number parsing only on the - first and third columns. - - """ - if tabular_data is None: - tabular_data = [] - list_of_lists, headers = _normalize_tabular_data( - tabular_data, headers, showindex=showindex) - - # empty values in the first column of RST tables should be escaped - # (issue #82). "" should be escaped as "\\ " or ".." 
- if tablefmt == 'rst': - list_of_lists, headers = _rst_escape_first_column(list_of_lists, - headers) - - # optimization: look for ANSI control codes once, - # enable smart width functions only if a control code is found - plain_text = '\n'.join(['\t'.join(map(_text_type, headers))] + - ['\t'.join(map(_text_type, row)) - for row in list_of_lists]) - - has_invisible = re.search(_invisible_codes, plain_text) - enable_widechars = wcwidth is not None and WIDE_CHARS_MODE - if has_invisible: - width_fn = _visible_width - elif enable_widechars: # optional wide-character support if available - width_fn = wcwidth.wcswidth - else: - width_fn = len - - # format rows and columns, convert numeric values to strings - cols = list(izip_longest(*list_of_lists)) - numparses = _expand_numparse(disable_numparse, len(cols)) - coltypes = [_column_type(col, numparse=np) for col, np in - zip(cols, numparses)] - if isinstance(floatfmt, basestring): # old version - # just duplicate the string to use in each column - float_formats = len(cols) * [floatfmt] - else: # if floatfmt is list, tuple etc we have one per column - float_formats = list(floatfmt) - if len(float_formats) < len(cols): - float_formats.extend((len(cols) - len(float_formats)) * - [_DEFAULT_FLOATFMT]) - if isinstance(missingval, basestring): - missing_vals = len(cols) * [missingval] - else: - missing_vals = list(missingval) - if len(missing_vals) < len(cols): - missing_vals.extend((len(cols) - len(missing_vals)) * - [_DEFAULT_MISSINGVAL]) - cols = [[_format(v, ct, fl_fmt, miss_v, has_invisible) for v in c] - for c, ct, fl_fmt, miss_v in zip(cols, coltypes, float_formats, - missing_vals)] - - # align columns - aligns = [numalign if ct in [int, float] else stralign for ct in coltypes] - minwidths = [width_fn(h) + MIN_PADDING - for h in headers] if headers else [0] * len(cols) - cols = [_align_column(c, a, minw, has_invisible) - for c, a, minw in zip(cols, aligns, minwidths)] - - if headers: - # align headers and add headers - 
t_cols = cols or [['']] * len(headers) - t_aligns = aligns or [stralign] * len(headers) - minwidths = [max(minw, width_fn(c[0])) - for minw, c in zip(minwidths, t_cols)] - headers = [_align_header(h, a, minw, width_fn(h)) - for h, a, minw in zip(headers, t_aligns, minwidths)] - rows = list(zip(*cols)) - else: - minwidths = [width_fn(c[0]) for c in cols] - rows = list(zip(*cols)) - - if not isinstance(tablefmt, TableFormat): - tablefmt = _table_formats.get(tablefmt, _table_formats["simple"]) - - return _format_table(tablefmt, headers, rows, minwidths, aligns) - - -def _expand_numparse(disable_numparse, column_count): - """Return a list of bools of length `column_count` which indicates whether - number parsing should be used on each column. - - If `disable_numparse` is a list of indices, each of those indices - are False, and everything else is True. If `disable_numparse` is a - bool, then the returned list is all the same. - - """ - if isinstance(disable_numparse, Iterable): - numparses = [True] * column_count - for index in disable_numparse: - numparses[index] = False - return numparses - else: - return [not disable_numparse] * column_count - - -def _build_simple_row(padded_cells, rowfmt): - "Format row according to DataRow format without padding." - begin, sep, end = rowfmt - return (begin + sep.join(padded_cells) + end).rstrip() - - -def _build_row(padded_cells, colwidths, colaligns, rowfmt): - "Return a string which represents a row of data cells." - if not rowfmt: - return None - if hasattr(rowfmt, "__call__"): - return rowfmt(padded_cells, colwidths, colaligns) - else: - return _build_simple_row(padded_cells, rowfmt) - - -def _build_line(colwidths, colaligns, linefmt): - "Return a string which represents a horizontal line." 
- if not linefmt: - return None - if hasattr(linefmt, "__call__"): - return linefmt(colwidths, colaligns) - else: - begin, fill, sep, end = linefmt - cells = [fill*w for w in colwidths] - return _build_simple_row(cells, (begin, sep, end)) - - -def _pad_row(cells, padding): - if cells: - pad = " "*padding - padded_cells = [pad + cell + pad for cell in cells] - return padded_cells - else: - return cells - - -def _format_table(fmt, headers, rows, colwidths, colaligns): - """Produce a plain-text representation of the table.""" - lines = [] - hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else [] - pad = fmt.padding - headerrow = fmt.headerrow - - padded_widths = [(w + 2*pad) for w in colwidths] - padded_headers = _pad_row(headers, pad) - padded_rows = [_pad_row(row, pad) for row in rows] - - if fmt.lineabove and "lineabove" not in hidden: - lines.append(_build_line(padded_widths, colaligns, fmt.lineabove)) - - if padded_headers: - lines.append(_build_row(padded_headers, padded_widths, colaligns, - headerrow)) - if fmt.linebelowheader and "linebelowheader" not in hidden: - lines.append(_build_line(padded_widths, colaligns, - fmt.linebelowheader)) - - if padded_rows and fmt.linebetweenrows and "linebetweenrows" not in hidden: - # initial rows with a line below - for row in padded_rows[:-1]: - lines.append(_build_row(row, padded_widths, colaligns, - fmt.datarow)) - lines.append(_build_line(padded_widths, colaligns, - fmt.linebetweenrows)) - # the last row without a line below - lines.append(_build_row(padded_rows[-1], padded_widths, colaligns, - fmt.datarow)) - else: - for row in padded_rows: - lines.append(_build_row(row, padded_widths, colaligns, - fmt.datarow)) - - if fmt.linebelow and "linebelow" not in hidden: - lines.append(_build_line(padded_widths, colaligns, fmt.linebelow)) - - if headers or rows: - return "\n".join(lines) - else: # a completely empty table - return "" - - -def _main(): - """\ Usage: tabulate [options] [FILE ...] 
- - Pretty-print tabular data. - See also https://bitbucket.org/astanin/python-tabulate - - FILE a filename of the file with tabular data; - if "-" or missing, read data from stdin. - - Options: - - -h, --help show this message - -1, --header use the first row of data as a table header - -o FILE, --output FILE print table to FILE (default: stdout) - -s REGEXP, --sep REGEXP use a custom column separator (default: whitespace) - -F FPFMT, --float FPFMT floating point number format (default: g) - -f FMT, --format FMT set output table format; supported formats: - plain, simple, grid, fancy_grid, pipe, orgtbl, - rst, mediawiki, html, latex, latex_raw, - latex_booktabs, tsv - (default: simple) - - """ - import getopt - import sys - import textwrap - usage = textwrap.dedent(_main.__doc__) - try: - opts, args = getopt.getopt( - sys.argv[1:], "h1o:s:F:f:", - ["help", "header", "output", "sep=", "float=", "format="]) - except getopt.GetoptError as e: - print(e) - print(usage) - sys.exit(2) - headers = [] - floatfmt = _DEFAULT_FLOATFMT - tablefmt = "simple" - sep = r"\s+" - outfile = "-" - for opt, value in opts: - if opt in ["-1", "--header"]: - headers = "firstrow" - elif opt in ["-o", "--output"]: - outfile = value - elif opt in ["-F", "--float"]: - floatfmt = value - elif opt in ["-f", "--format"]: - if value not in tabulate_formats: - print("%s is not a supported table format" % value) - print(usage) - sys.exit(3) - tablefmt = value - elif opt in ["-s", "--sep"]: - sep = value - elif opt in ["-h", "--help"]: - print(usage) - sys.exit(0) - files = [sys.stdin] if not args else args - with (sys.stdout if outfile == "-" else open(outfile, "w")) as out: - for f in files: - if f == "-": - f = sys.stdin - if _is_file(f): - _pprint_file(f, headers=headers, tablefmt=tablefmt, - sep=sep, floatfmt=floatfmt, file=out) - else: - with open(f) as fobj: - _pprint_file(fobj, headers=headers, tablefmt=tablefmt, - sep=sep, floatfmt=floatfmt, file=out) - - -def _pprint_file(fobject, headers, 
tablefmt, sep, floatfmt, file): - rows = fobject.readlines() - table = [re.split(sep, r.rstrip()) for r in rows if r.strip()] - print(tabulate(table, headers, tablefmt, floatfmt=floatfmt), file=file) - - -if __name__ == "__main__": - _main() diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index f01afb1d..3656aa69 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -1,5 +1,3 @@ -from __future__ import print_function -from __future__ import unicode_literals import logging from re import compile, escape from collections import Counter @@ -8,34 +6,40 @@ from .packages.completion_engine import suggest_type from .packages.parseutils import last_word -from .packages.special.favoritequeries import favoritequeries +from .packages.filepaths import parse_path, complete_path, suggest_path +from .packages.special.favoritequeries import FavoriteQueries _logger = logging.getLogger(__name__) class SQLCompleter(Completer): keywords = ['ACCESS', 'ADD', 'ALL', 'ALTER TABLE', 'AND', 'ANY', 'AS', - 'ASC', 'AUTO_INCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN', 'BINARY', 'BY', - 'CASE', 'CHAR', 'CHECK', 'COLUMN', 'COMMENT', 'COMMIT', 'CONSTRAINT', - 'CHANGE MASTER TO', 'CHARACTER SET', 'COLLATE', 'CREATE', 'CURRENT', 'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', - 'DECIMAL', 'DEFAULT', 'DELETE FROM', 'DELIMITER', 'DESC', - 'DESCRIBE', 'DROP', 'ELSE', 'END', 'ENGINE', 'ESCAPE', 'EXISTS', - 'FILE', 'FLOAT', 'FOR', 'FOREIGN KEY', 'FORMAT', 'FROM', 'FULL', 'FUNCTION', 'GRANT', - 'GROUP BY', 'HAVING', 'HOST', 'IDENTIFIED', 'IN', 'INCREMENT', 'INDEX', - 'INSERT INTO', 'INTEGER', 'INTO', 'INTERVAL', 'IS', 'JOIN', 'KEY', 'LEFT', - 'LEVEL', 'LIKE', 'LIMIT', 'LOCK', 'LOGS', 'LONG', 'MASTER', 'MODE', - 'MODIFY', 'NOT', 'NULL', 'NUMBER', 'OFFSET', 'ON', 'OPTION', 'OR', - 'ORDER BY', 'OUTER', 'OWNER', 'PASSWORD', 'PORT', 'PRIMARY', - 'PRIVILEGES', 'PROCESSLIST', 'PURGE', 'REFERENCES', 'REGEXP', 'RENAME', 'REPAIR', 'RESET', - 'REVOKE', 'RIGHT', 'ROLLBACK','ROW', 'ROWS', 'ROW_FORMAT', 
'SELECT', 'SESSION', 'SET', - 'SHARE', 'SHOW', 'SLAVE', 'SMALLINT', 'START', 'STOP', 'TABLE', 'THEN', - 'TO', 'TRANSACTION', 'TRIGGER', 'TRUNCATE', 'UNION', 'UNIQUE', 'UNSIGNED', 'UPDATE', - 'USE', 'USER', 'USING', 'VALUES', 'VARCHAR', 'VIEW', 'WHEN', 'WHERE', - 'WITH'] - - functions = ['AVG', 'COUNT', 'DISTINCT', 'FIRST', 'FORMAT', 'LAST', - 'LCASE', 'LEN', 'MAX', 'MIN', 'MID', 'NOW', 'ROUND', 'SUM', - 'TOP', 'UCASE'] + 'ASC', 'AUTO_INCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN', + 'BIGINT', 'BINARY', 'BY', 'CASE', 'CHANGE MASTER TO', 'CHAR', + 'CHARACTER SET', 'CHECK', 'COLLATE', 'COLUMN', 'COMMENT', + 'COMMIT', 'CONSTRAINT', 'CREATE', 'CURRENT', + 'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', 'DECIMAL', 'DEFAULT', + 'DELETE FROM', 'DESC', 'DESCRIBE', 'DROP', + 'ELSE', 'END', 'ENGINE', 'ESCAPE', 'EXISTS', 'FILE', 'FLOAT', + 'FOR', 'FOREIGN KEY', 'FORMAT', 'FROM', 'FULL', 'FUNCTION', + 'GRANT', 'GROUP BY', 'HAVING', 'HOST', 'IDENTIFIED', 'IN', + 'INCREMENT', 'INDEX', 'INSERT INTO', 'INT', 'INTEGER', + 'INTERVAL', 'INTO', 'IS', 'JOIN', 'KEY', 'LEFT', 'LEVEL', + 'LIKE', 'LIMIT', 'LOCK', 'LOGS', 'LONG', 'MASTER', + 'MEDIUMINT', 'MODE', 'MODIFY', 'NOT', 'NULL', 'NUMBER', + 'OFFSET', 'ON', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'OWNER', + 'PASSWORD', 'PORT', 'PRIMARY', 'PRIVILEGES', 'PROCESSLIST', + 'PURGE', 'REFERENCES', 'REGEXP', 'RENAME', 'REPAIR', 'RESET', + 'REVOKE', 'RIGHT', 'ROLLBACK', 'ROW', 'ROWS', 'ROW_FORMAT', + 'SAVEPOINT', 'SELECT', 'SESSION', 'SET', 'SHARE', 'SHOW', + 'SLAVE', 'SMALLINT', 'SMALLINT', 'START', 'STOP', 'TABLE', + 'THEN', 'TINYINT', 'TO', 'TRANSACTION', 'TRIGGER', 'TRUNCATE', + 'UNION', 'UNIQUE', 'UNSIGNED', 'UPDATE', 'USE', 'USER', + 'USING', 'VALUES', 'VARCHAR', 'VIEW', 'WHEN', 'WHERE', 'WITH'] + + functions = ['AVG', 'CONCAT', 'COUNT', 'DISTINCT', 'FIRST', 'FORMAT', + 'FROM_UNIXTIME', 'LAST', 'LCASE', 'LEN', 'MAX', 'MID', + 'MIN', 'NOW', 'ROUND', 'SUM', 'TOP', 'UCASE', 'UNIX_TIMESTAMP'] show_items = [] @@ -49,23 +53,26 @@ class 
SQLCompleter(Completer): users = [] - def __init__(self, smart_completion=True, supported_formats=()): + def __init__(self, smart_completion=True, supported_formats=(), keyword_casing='auto'): super(self.__class__, self).__init__() self.smart_completion = smart_completion self.reserved_words = set() for x in self.keywords: self.reserved_words.update(x.split()) - self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$") + self.name_pattern = compile(r"^[_a-z][_a-z0-9\$]*$") self.special_commands = [] self.table_formats = supported_formats + if keyword_casing not in ('upper', 'lower', 'auto'): + keyword_casing = 'auto' + self.keyword_casing = keyword_casing self.reset_completions() def escape_name(self, name): if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): - name = '`%s`' % name + name = '`%s`' % name return name @@ -195,7 +202,7 @@ def reset_completions(self): self.all_completions = set(self.keywords + self.functions) @staticmethod - def find_matches(text, collection, start_only=False, fuzzy=True): + def find_matches(text, collection, start_only=False, fuzzy=True, casing=None): """Find completion matches for the given text. Given the user's input text and a collection of available @@ -209,7 +216,8 @@ def find_matches(text, collection, start_only=False, fuzzy=True): yields prompt_toolkit Completion instances for any matches found in the collection of available completions. 
""" - text = last_word(text, include='most_punctuations').lower() + last = last_word(text, include='most_punctuations') + text = last.lower() completions = [] @@ -227,7 +235,16 @@ def find_matches(text, collection, start_only=False, fuzzy=True): if match_point >= 0: completions.append((len(text), match_point, item)) - return (Completion(z, -len(text)) for x, y, z in sorted(completions)) + if casing == 'auto': + casing = 'lower' if last and last[-1].islower() else 'upper' + + def apply_case(kw): + if casing == 'upper': + return kw.upper() + return kw.lower() + + return (Completion(z if casing is None else apply_case(z), -len(text)) + for x, y, z in sorted(completions)) def get_completions(self, document, complete_event, smart_completion=None): word_before_cursor = document.get_word_before_cursor(WORD=True) @@ -278,7 +295,8 @@ def get_completions(self, document, complete_event, smart_completion=None): predefined_funcs = self.find_matches(word_before_cursor, self.functions, start_only=True, - fuzzy=False) + fuzzy=False, + casing=self.keyword_casing) completions.extend(predefined_funcs) elif suggestion['type'] == 'table': @@ -305,14 +323,16 @@ def get_completions(self, document, complete_event, smart_completion=None): elif suggestion['type'] == 'keyword': keywords = self.find_matches(word_before_cursor, self.keywords, start_only=True, - fuzzy=False) + fuzzy=False, + casing=self.keyword_casing) completions.extend(keywords) elif suggestion['type'] == 'show': show_items = self.find_matches(word_before_cursor, self.show_items, start_only=False, - fuzzy=True) + fuzzy=True, + casing=self.keyword_casing) completions.extend(show_items) elif suggestion['type'] == 'change': @@ -335,7 +355,7 @@ def get_completions(self, document, complete_event, smart_completion=None): completions.extend(special) elif suggestion['type'] == 'favoritequery': queries = self.find_matches(word_before_cursor, - favoritequeries.list(), + FavoriteQueries.instance.list(), start_only=False, fuzzy=True) 
completions.extend(queries) elif suggestion['type'] == 'table_format': @@ -343,9 +363,26 @@ def get_completions(self, document, complete_event, smart_completion=None): self.table_formats, start_only=True, fuzzy=False) completions.extend(formats) + elif suggestion['type'] == 'file_name': + file_names = self.find_files(word_before_cursor) + completions.extend(file_names) return completions + def find_files(self, word): + """Yield matching directory or file names. + + :param word: + :return: iterable + + """ + base_path, last_path, position = parse_path(word) + paths = suggest_path(word) + for name in sorted(paths): + suggestion = complete_path(name, last_path) + if suggestion: + yield Completion(suggestion, position) + def populate_scoped_cols(self, scoped_tbls): """Find all columns in a set of scoped_tables :param scoped_tbls: list of (schema, table, alias) tuples diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 4d5c0a09..94614387 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -1,23 +1,90 @@ +import enum import logging +import re + import pymysql -import sqlparse from .packages import special from pymysql.constants import FIELD_TYPE -from pymysql.converters import (convert_mysql_timestamp, convert_datetime, - convert_timedelta, convert_date, conversions) +from pymysql.converters import (convert_datetime, + convert_timedelta, convert_date, conversions, + decoders) +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko _logger = logging.getLogger(__name__) +FIELD_TYPES = decoders.copy() +FIELD_TYPES.update({ + FIELD_TYPE.NULL: type(None) +}) + + +ERROR_CODE_ACCESS_DENIED = 1045 + + +class ServerSpecies(enum.Enum): + MySQL = 'MySQL' + MariaDB = 'MariaDB' + Percona = 'Percona' + Unknown = 'MySQL' + + +class ServerInfo: + def __init__(self, species, version_str): + self.species = species + self.version_str = version_str + self.version = self.calc_mysql_version_value(version_str) + + @staticmethod + def 
calc_mysql_version_value(version_str) -> int: + if not version_str or not isinstance(version_str, str): + return 0 + try: + major, minor, patch = version_str.split('.') + except ValueError: + return 0 + else: + return int(major) * 10_000 + int(minor) * 100 + int(patch) + + @classmethod + def from_version_string(cls, version_string): + if not version_string: + return cls(ServerSpecies.Unknown, '') + + re_species = ( + (r'(?P[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), + (r'(?P[0-9\.]+)[a-z0-9]*-(?P[0-9]+$)', + ServerSpecies.Percona), + (r'(?P[0-9\.]+)[a-z0-9]*-(?P[A-Za-z0-9_]+)', + ServerSpecies.MySQL), + ) + for regexp, species in re_species: + match = re.search(regexp, version_string) + if match is not None: + parsed_version = match.group('version') + detected_species = species + break + else: + detected_species = ServerSpecies.Unknown + parsed_version = '' + + return cls(detected_species, parsed_version) + + def __str__(self): + if self.species: + return f'{self.species.value} {self.version_str}' + else: + return self.version_str + + class SQLExecute(object): databases_query = '''SHOW DATABASES''' tables_query = '''SHOW TABLES''' - version_query = '''SELECT @@VERSION''' - - version_comment_query = '''SELECT @@VERSION_COMMENT''' - show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' @@ -30,7 +97,8 @@ class SQLExecute(object): order by table_name,ordinal_position''' def __init__(self, database, user, password, host, port, socket, charset, - local_infile, ssl=False): + local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, + ssh_key_filename, init_command=None): self.dbname = database self.user = user self.password = password @@ -40,12 +108,20 @@ def __init__(self, database, user, password, host, port, socket, charset, self.charset = charset self.local_infile = local_infile self.ssl = ssl - self._server_type = None + self.server_info = None 
self.connection_id = None + self.ssh_user = ssh_user + self.ssh_host = ssh_host + self.ssh_port = ssh_port + self.ssh_password = ssh_password + self.ssh_key_filename = ssh_key_filename + self.init_command = init_command self.connect() def connect(self, database=None, user=None, password=None, host=None, - port=None, socket=None, charset=None, local_infile=None, ssl=None): + port=None, socket=None, charset=None, local_infile=None, + ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, + ssh_password=None, ssh_key_filename=None, init_command=None): db = (database or self.dbname) user = (user or self.user) password = (password or self.password) @@ -55,7 +131,14 @@ def connect(self, database=None, user=None, password=None, host=None, charset = (charset or self.charset) local_infile = (local_infile or self.local_infile) ssl = (ssl or self.ssl) - _logger.debug('Connection DB Params: \n' + ssh_user = (ssh_user or self.ssh_user) + ssh_host = (ssh_host or self.ssh_host) + ssh_port = (ssh_port or self.ssh_port) + ssh_password = (ssh_password or self.ssh_password) + ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) + init_command = (init_command or self.init_command) + _logger.debug( + 'Connection DB Params: \n' '\tdatabase: %r' '\tuser: %r' '\thost: %r' @@ -63,22 +146,57 @@ def connect(self, database=None, user=None, password=None, host=None, '\tsocket: %r' '\tcharset: %r' '\tlocal_infile: %r' - '\tssl: %r', - database, user, host, port, socket, charset, local_infile, ssl) + '\tssl: %r' + '\tssh_user: %r' + '\tssh_host: %r' + '\tssh_port: %r' + '\tssh_password: %r' + '\tssh_key_filename: %r' + '\tinit_command: %r', + db, user, host, port, socket, charset, local_infile, ssl, + ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, + init_command + ) conv = conversions.copy() conv.update({ - FIELD_TYPE.TIMESTAMP: lambda obj: (convert_mysql_timestamp(obj) or obj), + FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj), FIELD_TYPE.DATETIME: 
lambda obj: (convert_datetime(obj) or obj), FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj), FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj), }) - conn = pymysql.connect(database=db, user=user, password=password, - host=host, port=port, unix_socket=socket, - use_unicode=True, charset=charset, autocommit=True, - client_flag=pymysql.constants.CLIENT.INTERACTIVE, - local_infile=local_infile, - conv=conv, ssl=ssl) + defer_connect = False + + if ssh_host: + defer_connect = True + + client_flag = pymysql.constants.CLIENT.INTERACTIVE + if init_command and len(list(special.split_queries(init_command))) > 1: + client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS + + conn = pymysql.connect( + database=db, user=user, password=password, host=host, port=port, + unix_socket=socket, use_unicode=True, charset=charset, + autocommit=True, client_flag=client_flag, + local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", + defer_connect=defer_connect, init_command=init_command + ) + + if ssh_host: + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + client.connect( + ssh_host, ssh_port, ssh_user, ssh_password, + key_filename=ssh_key_filename + ) + chan = client.get_transport().open_channel( + 'direct-tcpip', + (host, port), + ('0.0.0.0', 0), + ) + conn.connect(chan) + if hasattr(self, 'conn'): self.conn.close() self.conn = conn @@ -92,8 +210,10 @@ def connect(self, database=None, user=None, password=None, host=None, self.socket = socket self.charset = charset self.ssl = ssl + self.init_command = init_command # retrieve connection id self.reset_connection_id() + self.server_info = ServerInfo.from_version_string(conn.server_version) def run(self, statement): """Execute the sql in the database and return the results. The results @@ -111,43 +231,48 @@ def run(self, statement): # want to save them all together. 
if statement.startswith('\\fs'): components = [statement] - else: - components = sqlparse.split(statement) + components = special.split_queries(statement) for sql in components: - # Remove spaces, eol and semi-colons. - sql = sql.rstrip(';') - # \G is treated specially since we have to set the expanded output. if sql.endswith('\\G'): special.set_expanded_output(True) sql = sql[:-2].strip() + + cur = self.conn.cursor() try: # Special command _logger.debug('Trying a dbspecial command. sql: %r', sql) - cur = self.conn.cursor() for result in special.execute(cur, sql): yield result except special.CommandNotFound: # Regular SQL - yield self.execute_normal_sql(sql) - - def execute_normal_sql(self, split_sql): - _logger.debug('Regular sql statement. sql: %r', split_sql) - cur = self.conn.cursor() - num_rows = cur.execute(split_sql) + _logger.debug('Regular sql statement. sql: %r', sql) + cur.execute(sql) + while True: + yield self.get_result(cur) + + # PyMySQL returns an extra, empty result set with stored + # procedures. We skip it (rowcount is zero and no + # description). + if not cur.nextset() or (not cur.rowcount and cur.description is None): + break + + def get_result(self, cursor): + """Get the current result's data from the cursor.""" title = headers = None - # cur.description is not None for queries that return result sets, e.g. - # SELECT or SHOW. - if cur.description is not None: - headers = [x[0] for x in cur.description] + # cursor.description is not None for queries that return result sets, + # e.g. SELECT or SHOW. 
+ if cursor.description is not None: + headers = [x[0] for x in cursor.description] status = '{0} row{1} in set' else: _logger.debug('No rows in result.') status = 'Query OK, {0} row{1} affected' - status = status.format(num_rows, '' if num_rows == 1 else 's') + status = status.format(cursor.rowcount, + '' if cursor.rowcount == 1 else 's') - return (title, cur if cur.description else None, headers, status) + return (title, cursor if cursor.description else None, headers, status) def tables(self): """Yields table names""" @@ -159,7 +284,7 @@ def tables(self): yield row def table_columns(self): - """Yields column names""" + """Yields (table name, column name) pairs""" with self.conn.cursor() as cur: _logger.debug('Columns Query. sql: %r', self.table_columns_query) cur.execute(self.table_columns_query % self.dbname) @@ -186,7 +311,7 @@ def show_candidates(self): _logger.debug('Show Query. sql: %r', self.show_candidates_query) try: cur.execute(self.show_candidates_query) - except pymysql.OperationalError as e: + except pymysql.DatabaseError as e: _logger.error('No show completions due to %r', e) yield '' else: @@ -198,34 +323,13 @@ def users(self): _logger.debug('Users Query. sql: %r', self.users_query) try: cur.execute(self.users_query) - except pymysql.OperationalError as e: + except pymysql.DatabaseError as e: _logger.error('No user completions due to %r', e) yield '' else: for row in cur: yield row - def server_type(self): - if self._server_type: - return self._server_type - with self.conn.cursor() as cur: - _logger.debug('Version Query. sql: %r', self.version_query) - cur.execute(self.version_query) - version = cur.fetchone()[0] - _logger.debug('Version Comment. 
sql: %r', self.version_comment_query) - cur.execute(self.version_comment_query) - version_comment = cur.fetchone()[0].lower() - - if 'mariadb' in version_comment: - product_type = 'mariadb' - elif 'percona' in version_comment: - product_type = 'percona' - else: - product_type = 'mysql' - - self._server_type = (product_type, version) - return self._server_type - def get_connection_id(self): if not self.connection_id: self.reset_connection_id() @@ -238,3 +342,7 @@ def reset_connection_id(self): for title, cur, headers, status in res: self.connection_id = cur.fetchone()[0] _logger.debug('Current connection id: %s', self.connection_id) + + def change_db(self, db): + self.conn.select_db(db) + self.dbname = db diff --git a/pytest.ini b/pytest.ini index 7c5b52b7..5422131c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,2 @@ [pytest] -addopts=--capture=sys --showlocals --doctest-modules +addopts = --ignore=mycli/packages/paramiko_stub/__init__.py diff --git a/release.py b/release.py index 81769583..39df8a3a 100755 --- a/release.py +++ b/release.py @@ -1,15 +1,11 @@ -#!/usr/bin/env python -from __future__ import print_function +"""A script to publish a release of mycli to PyPI.""" + +from optparse import OptionParser import re -import ast import subprocess import sys -from optparse import OptionParser -try: - input = raw_input -except NameError: - pass +import click DEBUG = False CONFIRM_STEPS = False @@ -24,9 +20,7 @@ def skip_step(): global CONFIRM_STEPS if CONFIRM_STEPS: - choice = input("--- Confirm step? 
(y/N) [y] ") - if choice.lower() == 'n': - return True + return not click.confirm('--- Run this step?', default=True) return False @@ -49,11 +43,11 @@ def run_step(*args): def version(version_file): - _version_re = re.compile(r'__version__\s+=\s+(.*)') + _version_re = re.compile( + r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)') - with open(version_file, 'rb') as f: - ver = str(ast.literal_eval(_version_re.search( - f.read().decode('utf-8')).group(1))) + with open(version_file) as f: + ver = _version_re.search(f.read()).group('version') return ver @@ -61,17 +55,14 @@ def version(version_file): def commit_for_release(version_file, ver): run_step('git', 'reset') run_step('git', 'add', version_file) - run_step('git', 'commit', '--message', 'Releasing version %s' % ver) + run_step('git', 'commit', '--message', + 'Releasing version {}'.format(ver)) def create_git_tag(tag_name): run_step('git', 'tag', tag_name) -def register_with_pypi(): - run_step('python', 'setup.py', 'register') - - def create_distribution_files(): run_step('python', 'setup.py', 'sdist', 'bdist_wheel') @@ -81,7 +72,7 @@ def upload_distribution_files(): def push_to_github(): - run_step('git', 'push', 'origin', 'master') + run_step('git', 'push', 'origin', 'main') def push_tags_to_github(): @@ -90,8 +81,7 @@ def push_tags_to_github(): def checklist(questions): for question in questions: - choice = input(question + ' (y/N) [n] ') - if choice.lower() != 'y': + if not click.confirm('--- {}'.format(question), default=False): sys.exit(1) @@ -99,12 +89,6 @@ def checklist(questions): if DEBUG: subprocess.check_output = lambda x: x - checks = ['Have you created the debian package?', - 'Have you updated the AUTHORS file?', - 'Have you updated the `Usage` section of the README?', - ] - checklist(checks) - ver = version('mycli/__init__.py') print('Releasing Version:', ver) @@ -123,13 +107,11 @@ def checklist(questions): CONFIRM_STEPS = popts.confirm_steps DRY_RUN = popts.dry_run - choice = input('Are you sure? 
(y/N) [n] ') - if choice.lower() != 'y': + if not click.confirm('Are you sure?', default=False): sys.exit(1) commit_for_release('mycli/__init__.py', ver) - create_git_tag('v%s' % ver) - register_with_pypi() + create_git_tag('v{}'.format(ver)) create_distribution_files() push_to_github() push_tags_to_github() diff --git a/release_procedure.txt b/release_procedure.txt deleted file mode 100644 index 1b935b62..00000000 --- a/release_procedure.txt +++ /dev/null @@ -1,14 +0,0 @@ -# vi: ft=vimwiki - -* Bump the version number in mycli/__init__.py -* Commit with message: 'Releasing version X.X.X.' -* Create a tag: git tag vX.X.X -* Register with pypi for new version: python setup.py register -* Fix the image url in PyPI to point to github raw content. https://raw.githubusercontent.com/dbcli/mysql-cli/master/screenshots/image01.png -* Create source dist tar ball: python setup.py sdist -* Test this by installing it in a fresh new virtualenv. Run SanityChecks [./sanity_checks.txt]. -* Upload the source dist to PyPI: https://pypi.python.org/pypi/mycli -* pip install mycli -* Run SanityChecks. -* Push the version back to github: git push --tags origin master -* Done! 
diff --git a/requirements-dev.txt b/requirements-dev.txt index b7e6e2dc..9c403160 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,13 @@ -mock -pytest +pytest!=3.3.0 +pytest-cov==2.4.0 tox -twine==1.8.1 -behave -pexpect -coverage==4.3.4 -pep8radius +twine==1.12.1 +behave>=1.2.4 +pexpect==3.3 +coverage==5.0.4 +codecov==2.0.9 +autopep8==1.3.3 +colorama==0.4.1 +git+https://github.com/hayd/pep8radius.git # --error-status option not released +click>=7.0 +paramiko==2.7.1 diff --git a/screenshots/tables.png b/screenshots/tables.png index bb8859a5..1d6afcf2 100644 Binary files a/screenshots/tables.png and b/screenshots/tables.png differ diff --git a/setup.cfg b/setup.cfg index 2a9acf13..e533c7b7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,18 @@ [bdist_wheel] universal = 1 + +[tool:pytest] +addopts = --capture=sys + --showlocals + --doctest-modules + --doctest-ignore-import-errors + --ignore=setup.py + --ignore=mycli/magic.py + --ignore=mycli/packages/parseutils.py + --ignore=test/features + +[pep8] +rev = master +docformatter = True +diff = True +error-status = True diff --git a/setup.py b/setup.py old mode 100644 new mode 100755 index 33ab724f..f79bcd77 --- a/setup.py +++ b/setup.py @@ -1,27 +1,91 @@ -import re +#!/usr/bin/env python + import ast -import platform -from setuptools import setup, find_packages +import re +import subprocess +import sys + +from setuptools import Command, find_packages, setup +from setuptools.command.test import test as TestCommand _version_re = re.compile(r'__version__\s+=\s+(.*)') -with open('mycli/__init__.py', 'rb') as f: - version = str(ast.literal_eval(_version_re.search( - f.read().decode('utf-8')).group(1))) +with open('mycli/__init__.py') as f: + version = ast.literal_eval(_version_re.search( + f.read()).group(1)) description = 'CLI for MySQL Database. With auto-completion and syntax highlighting.' 
install_requirements = [ - 'click >= 4.1', - 'Pygments >= 1.6', - 'prompt_toolkit>=1.0.10,<1.1.0', - 'PyMySQL >= 0.6.7', - 'sqlparse>=0.2.2,<0.3.0', + 'click >= 7.0', + 'cryptography >= 1.0.0', + # 'Pygments>=1.6,<=2.11.1', + 'Pygments>=1.6', + 'prompt_toolkit>=3.0.6,<4.0.0', + 'PyMySQL >= 0.9.2', + 'sqlparse>=0.3.0,<0.5.0', 'configobj >= 5.0.5', - 'pycryptodome >= 3', - 'terminaltables >= 3.0.0', + 'cli_helpers[styles] >= 2.2.1', + 'pyperclip >= 1.8.1', + 'pyaes >= 1.6.1' ] +if sys.version_info.minor < 9: + install_requirements.append('importlib_resources >= 5.0.0') + + +class lint(Command): + description = 'check code against PEP 8 (and fix violations)' + + user_options = [ + ('branch=', 'b', 'branch/revision to compare against (e.g. master)'), + ('fix', 'f', 'fix the violations in place'), + ('error-status', 'e', 'return an error code on failed PEP check'), + ] + + def initialize_options(self): + """Set the default options.""" + self.branch = 'master' + self.fix = False + self.error_status = True + + def finalize_options(self): + pass + + def run(self): + cmd = 'pep8radius {}'.format(self.branch) + if self.fix: + cmd += ' --in-place' + if self.error_status: + cmd += ' --error-status' + sys.exit(subprocess.call(cmd, shell=True)) + + +class test(TestCommand): + + user_options = [ + ('pytest-args=', 'a', 'Arguments to pass to pytest'), + ('behave-args=', 'b', 'Arguments to pass to pytest') + ] + + def initialize_options(self): + TestCommand.initialize_options(self) + self.pytest_args = '' + self.behave_args = '--no-capture' + + def run_tests(self): + unit_test_errno = subprocess.call( + 'pytest test/ ' + self.pytest_args, + shell=True + ) + cli_errno = subprocess.call( + 'behave test/features ' + self.behave_args, + shell=True + ) + subprocess.run(['git', 'checkout', '--', 'test/myclirc'], check=False) + sys.exit(unit_test_errno or cli_errno) + + setup( name='mycli', author='Mycli Core Team', @@ -29,29 +93,31 @@ version=version, url='http://mycli.net', 
packages=find_packages(), - package_data={'mycli': ['myclirc', '../AUTHORS', '../SPONSORS']}, + package_data={'mycli': ['myclirc', 'AUTHORS', 'SPONSORS']}, description=description, long_description=description, install_requires=install_requirements, - entry_points=''' - [console_scripts] - mycli=mycli.main:cli - ''', + entry_points={ + 'console_scripts': ['mycli = mycli.main:cli'], + }, + cmdclass={'lint': lint, 'test': test}, + python_requires=">=3.6", classifiers=[ 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', 'Operating System :: Unix', 'Programming Language :: Python', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Programming Language :: SQL', 'Topic :: Database', 'Topic :: Database :: Front-Ends', 'Topic :: Software Development', 'Topic :: Software Development :: Libraries :: Python Modules', ], + extras_require={ + 'ssh': ['paramiko'], + }, ) diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/test/conftest.py similarity index 59% rename from tests/conftest.py rename to test/conftest.py index d24d26bc..d7d10ce3 100644 --- a/tests/conftest.py +++ b/test/conftest.py @@ -1,9 +1,10 @@ import pytest -from utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection) +from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, + db_connection, SSH_USER, SSH_HOST, SSH_PORT) import mycli.sqlexecute -@pytest.yield_fixture(scope="function") +@pytest.fixture(scope="function") def connection(): create_db('_test_db') connection = db_connection('_test_db') @@ -23,4 +24,6 @@ def executor(connection): return mycli.sqlexecute.SQLExecute( database='_test_db', 
user=USER, host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET, - local_infile=False) + local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST, + ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None + ) diff --git a/test/features/__init__.py b/test/features/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/features/auto_vertical.feature b/test/features/auto_vertical.feature new file mode 100644 index 00000000..aa957186 --- /dev/null +++ b/test/features/auto_vertical.feature @@ -0,0 +1,12 @@ +Feature: auto_vertical mode: + on, off + + Scenario: auto_vertical on with small query + When we run dbcli with --auto-vertical-output + and we execute a small query + then we see small results in horizontal format + + Scenario: auto_vertical on with large query + When we run dbcli with --auto-vertical-output + and we execute a large query + then we see large results in vertical format diff --git a/test/features/basic_commands.feature b/test/features/basic_commands.feature new file mode 100644 index 00000000..a12e8992 --- /dev/null +++ b/test/features/basic_commands.feature @@ -0,0 +1,19 @@ +Feature: run the cli, + call the help command, + exit the cli + + Scenario: run "\?" command + When we send "\?" 
command + then we see help output + + Scenario: run source command + When we send source command + then we see help output + + Scenario: check our application_name + When we run query to check application_name + then we see found + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits diff --git a/test/features/connection.feature b/test/features/connection.feature new file mode 100644 index 00000000..b06935ea --- /dev/null +++ b/test/features/connection.feature @@ -0,0 +1,35 @@ +Feature: connect to a database: + + @requires_local_db + Scenario: run mycli on localhost without port + When we run mycli with arguments "host=localhost" without arguments "port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli on TCP host without port + When we run mycli without arguments "port" + When we query "status" + Then status contains "via TCP/IP" + + Scenario: run mycli with port but without host + When we run mycli without arguments "host" + When we query "status" + Then status contains "via TCP/IP" + + @requires_local_db + Scenario: run mycli without host and port + When we run mycli without arguments "host port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli with my.cnf configuration + When we create my.cnf file + When we run mycli without arguments "host port user pass defaults_file" + Then we are logged in + + Scenario: run mycli with mylogin.cnf configuration + When we create mylogin.cnf file + When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file" + Then we are logged in + + diff --git a/test/features/crud_database.feature b/test/features/crud_database.feature new file mode 100644 index 00000000..f4a7a7f1 --- /dev/null +++ b/test/features/crud_database.feature @@ -0,0 +1,30 @@ +Feature: manipulate databases: + create, drop, connect, disconnect + + Scenario: create and drop temporary database + When we 
create database + then we see database created + when we drop database + then we confirm the destructive warning + then we see database dropped + when we connect to dbserver + then we see database connected + + Scenario: connect and disconnect from test database + When we connect to test database + then we see database connected + when we connect to dbserver + then we see database connected + + Scenario: connect and disconnect from quoted test database + When we connect to quoted test database + then we see database connected + + Scenario: create and drop default database + When we create database + then we see database created + when we connect to tmp database + then we see database connected + when we drop database + then we confirm the destructive warning + then we see database dropped and no default database diff --git a/test/features/crud_table.feature b/test/features/crud_table.feature new file mode 100644 index 00000000..3384efd7 --- /dev/null +++ b/test/features/crud_table.feature @@ -0,0 +1,49 @@ +Feature: manipulate tables: + create, insert, update, select, delete from, drop + + Scenario: create, insert, select from, update, drop table + When we connect to test database + then we see database connected + when we create table + then we see table created + when we insert into table + then we see record inserted + when we update table + then we see record updated + when we select from table + then we see data selected + when we delete from table + then we confirm the destructive warning + then we see record deleted + when we drop table + then we confirm the destructive warning + then we see table dropped + when we connect to dbserver + then we see database connected + + Scenario: select null values + When we connect to test database + then we see database connected + when we select null + then we see null selected + + Scenario: confirm destructive query + When we query "create table foo(x integer);" + and we query "delete from foo;" + and we answer the 
destructive warning with "y" + then we see text "Your call!" + + Scenario: decline destructive query + When we query "delete from foo;" + and we answer the destructive warning with "n" + then we see text "Wise choice!" + + Scenario: no destructive warning if disabled in config + When we run dbcli with --no-warn + and we query "create table blabla(x integer);" + and we query "delete from blabla;" + Then we see text "Query OK" + + Scenario: confirm destructive query with invalid response + When we query "delete from foo;" + then we answer the destructive warning with invalid "1" and see text "is not a valid boolean" diff --git a/tests/features/db_utils.py b/test/features/db_utils.py similarity index 62% rename from tests/features/db_utils.py rename to test/features/db_utils.py index f604c608..be550e9f 100644 --- a/tests/features/db_utils.py +++ b/test/features/db_utils.py @@ -1,21 +1,21 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -from __future__ import print_function - import pymysql -def create_db(hostname='localhost', username=None, password=None, - dbname=None): - """ - Create test database. + +def create_db(hostname='localhost', port=3306, username=None, + password=None, dbname=None): + """Create test database. 
+ :param hostname: string + :param port: int :param username: string :param password: string :param dbname: string :return: + """ cn = pymysql.connect( host=hostname, + port=port, user=username, password=password, charset='utf8mb4', @@ -23,26 +23,29 @@ def create_db(hostname='localhost', username=None, password=None, ) with cn.cursor() as cr: - cr.execute('drop database if exists '+dbname) - cr.execute('create database '+dbname) + cr.execute('drop database if exists ' + dbname) + cr.execute('create database ' + dbname) cn.close() - cn = create_cn(hostname, password, username, dbname) + cn = create_cn(hostname, port, password, username, dbname) return cn -def create_cn(hostname, password, username, dbname): - """ - Open connection to database. +def create_cn(hostname, port, password, username, dbname): + """Open connection to database. + :param hostname: + :param port: :param password: :param username: :param dbname: string :return: psycopg2.connection + """ cn = pymysql.connect( host=hostname, + port=port, user=username, password=password, db=dbname, @@ -53,17 +56,20 @@ def create_cn(hostname, password, username, dbname): return cn -def drop_db(hostname='localhost', username=None, password=None, - dbname=None): - """ - Drop database. +def drop_db(hostname='localhost', port=3306, username=None, + password=None, dbname=None): + """Drop database. + :param hostname: string + :param port: int :param username: string :param password: string :param dbname: string + """ cn = pymysql.connect( host=hostname, + port=port, user=username, password=password, db=dbname, @@ -72,15 +78,16 @@ def drop_db(hostname='localhost', username=None, password=None, ) with cn.cursor() as cr: - cr.execute('drop database if exists '+dbname) + cr.execute('drop database if exists ' + dbname) close_cn(cn) def close_cn(cn=None): - """ - Close connection. + """Close connection. 
+ :param connection: pymysql.connection + """ if cn: cn.close() diff --git a/test/features/environment.py b/test/features/environment.py new file mode 100644 index 00000000..1ea0f086 --- /dev/null +++ b/test/features/environment.py @@ -0,0 +1,176 @@ +import os +import shutil +import sys +from tempfile import mkstemp + +import db_utils as dbutils +import fixture_utils as fixutils +import pexpect + +from steps.wrappers import run_cli, wait_prompt + +test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') + + +SELF_CONNECTING_FEATURES = ( + 'test/features/connection.feature', +) + + +MY_CNF_PATH = os.path.expanduser('~/.my.cnf') +MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' +MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') +MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' + + +def get_db_name_from_context(context): + return context.config.userdata.get( + 'my_test_db', None + ) or "mycli_behave_tests" + + + +def before_all(context): + """Set env parameters.""" + os.environ['LINES'] = "100" + os.environ['COLUMNS'] = "100" + os.environ['EDITOR'] = 'ex' + os.environ['LC_ALL'] = 'en_US.UTF-8' + os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1' + os.environ['MYCLI_HISTFILE'] = os.devnull + + test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + login_path_file = os.path.join(test_dir, 'mylogin.cnf') +# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file + + context.package_root = os.path.abspath( + os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, + '.coveragerc') + + context.exit_sent = False + + vi = '_'.join([str(x) for x in sys.version_info[:3]]) + db_name = get_db_name_from_context(context) + db_name_full = '{0}_{1}'.format(db_name, vi) + + # Store get params from config/environment variables + context.conf = { + 'host': context.config.userdata.get( + 'my_test_host', + os.getenv('PYTEST_HOST', 'localhost') + ), + 'port': 
context.config.userdata.get( + 'my_test_port', + int(os.getenv('PYTEST_PORT', '3306')) + ), + 'user': context.config.userdata.get( + 'my_test_user', + os.getenv('PYTEST_USER', 'root') + ), + 'pass': context.config.userdata.get( + 'my_test_pass', + os.getenv('PYTEST_PASSWORD', None) + ), + 'cli_command': context.config.userdata.get( + 'my_cli_command', None) or + sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', + 'dbname': db_name, + 'dbname_tmp': db_name_full + '_tmp', + 'vi': vi, + 'pager_boundary': '---boundary---', + } + + _, my_cnf = mkstemp() + with open(my_cnf, 'w') as f: + f.write( + '[client]\n' + 'pager={0} {1} {2}\n'.format( + sys.executable, os.path.join(context.package_root, + 'test/features/wrappager.py'), + context.conf['pager_boundary']) + ) + context.conf['defaults-file'] = my_cnf + context.conf['myclirc'] = os.path.join(context.package_root, 'test', + 'myclirc') + + context.cn = dbutils.create_db(context.conf['host'], context.conf['port'], + context.conf['user'], + context.conf['pass'], + context.conf['dbname']) + + context.fixture_data = fixutils.read_fixture_files() + + +def after_all(context): + """Unset env parameters.""" + dbutils.close_cn(context.cn) + dbutils.drop_db(context.conf['host'], context.conf['port'], + context.conf['user'], context.conf['pass'], + context.conf['dbname']) + + # Restore env vars. 
+ #for k, v in context.pgenv.items(): + # if k in os.environ and v is None: + # del os.environ[k] + # elif v: + # os.environ[k] = v + + +def before_step(context, _): + context.atprompt = False + + +def before_scenario(context, arg): + with open(test_log_file, 'w') as f: + f.write('') + if arg.location.filename not in SELF_CONNECTING_FEATURES: + run_cli(context) + wait_prompt(context) + + if os.path.exists(MY_CNF_PATH): + shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH) + + if os.path.exists(MYLOGIN_CNF_PATH): + shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH) + + +def after_scenario(context, _): + """Cleans up after each test complete.""" + with open(test_log_file) as f: + for line in f: + if 'error' in line.lower(): + raise RuntimeError(f'Error in log file: {line}') + + if hasattr(context, 'cli') and not context.exit_sent: + # Quit nicely. + if not context.atprompt: + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + context.cli.expect_exact( + '{0}@{1}:{2}>'.format( + user, host, dbname + ), + timeout=5 + ) + context.cli.sendcontrol('c') + context.cli.sendcontrol('d') + context.cli.expect_exact(pexpect.EOF, timeout=5) + + if os.path.exists(MY_CNF_BACKUP_PATH): + shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH) + + if os.path.exists(MYLOGIN_CNF_BACKUP_PATH): + shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH) + elif os.path.exists(MYLOGIN_CNF_PATH): + # This file was moved in `before_scenario`. 
+ # If it exists now, it has been created during a test + os.remove(MYLOGIN_CNF_PATH) + + +# TODO: uncomment to debug a failure +# def after_step(context, step): +# if step.status == "failed": +# import ipdb; ipdb.set_trace() diff --git a/tests/features/fixture_data/help.txt b/test/features/fixture_data/help.txt similarity index 100% rename from tests/features/fixture_data/help.txt rename to test/features/fixture_data/help.txt diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt new file mode 100644 index 00000000..2c06d5d2 --- /dev/null +++ b/test/features/fixture_data/help_commands.txt @@ -0,0 +1,31 @@ ++-------------+----------------------------+------------------------------------------------------------+ +| Command | Shortcut | Description | ++-------------+----------------------------+------------------------------------------------------------+ +| \G | \G | Display current query results vertically. | +| \clip | \clip | Copy query to the system clipboard. | +| \dt | \dt[+] [table] | List or describe tables. | +| \e | \e | Edit command with editor (uses $EDITOR). | +| \f | \f [name [args..]] | List or execute favorite queries. | +| \fd | \fd [name] | Delete a favorite query. | +| \fs | \fs name query | Save a favorite query. | +| \l | \l | List databases. | +| \once | \o [-o] filename | Append next result to an output file (overwrite using -o). | +| \pipe_once | \| command | Send next result to a subprocess. | +| \timing | \t | Toggle timing of commands. | +| connect | \r | Reconnect to the database. Optional database argument. | +| exit | \q | Exit. | +| help | \? | Show this help. | +| nopager | \n | Disable pager, print to stdout. | +| notee | notee | Stop writing results to an output file. | +| pager | \P [command] | Set PAGER. Print the query results via PAGER. | +| prompt | \R | Change prompt format. | +| quit | \q | Quit. | +| rehash | \# | Refresh auto-completions. | +| source | \. 
filename | Execute commands from file. | +| status | \s | Get status information from the server. | +| system | system [command] | Execute a system shell commmand. | +| tableformat | \T | Change the table format used to output results. | +| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). | +| use | \u | Change to a new database. | +| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). | ++-------------+----------------------------+------------------------------------------------------------+ diff --git a/tests/features/fixture_utils.py b/test/features/fixture_utils.py similarity index 72% rename from tests/features/fixture_utils.py rename to test/features/fixture_utils.py index f3b490c4..f85e0f65 100644 --- a/tests/features/fixture_utils.py +++ b/test/features/fixture_utils.py @@ -1,26 +1,22 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - import os import io def read_fixture_lines(filename): - """ - Read lines of text from file. + """Read lines of text from file. + :param filename: string name :return: list of strings + """ lines = [] - for line in io.open(filename, 'r', encoding='utf8'): + for line in open(filename): lines.append(line.strip()) return lines def read_fixture_files(): - """ - Read all files inside fixture_data directory. 
- """ + """Read all files inside fixture_data directory.""" fixture_dict = {} current_dir = os.path.dirname(__file__) diff --git a/test/features/iocommands.feature b/test/features/iocommands.feature new file mode 100644 index 00000000..95366eba --- /dev/null +++ b/test/features/iocommands.feature @@ -0,0 +1,47 @@ +Feature: I/O commands + + Scenario: edit sql in file with external editor + When we start external editor providing a file name + and we type "select * from abc" in the editor + and we exit the editor + then we see dbcli prompt + and we see "select * from abc" in prompt + + Scenario: tee output from query + When we tee output + and we wait for prompt + and we select "select 123456" + and we wait for prompt + and we notee output + and we wait for prompt + then we see 123456 in tee output + + Scenario: set delimiter + When we query "delimiter $" + then delimiter is set to "$" + + Scenario: set delimiter twice + When we query "delimiter $" + and we query "delimiter ]]" + then delimiter is set to "]]" + + Scenario: set delimiter and query on same line + When we query "select 123; delimiter $ select 456 $ delimiter %" + then we see result "123" + and we see result "456" + and delimiter is set to "%" + + Scenario: send output to file + When we query "\o /tmp/output1.sql" + and we query "select 123" + and we query "system cat /tmp/output1.sql" + then we see result "123" + + Scenario: send output to file two times + When we query "\o /tmp/output1.sql" + and we query "select 123" + and we query "\o /tmp/output2.sql" + and we query "select 456" + and we query "system cat /tmp/output2.sql" + then we see result "456" + \ No newline at end of file diff --git a/test/features/named_queries.feature b/test/features/named_queries.feature new file mode 100644 index 00000000..5e681ec4 --- /dev/null +++ b/test/features/named_queries.feature @@ -0,0 +1,24 @@ +Feature: named queries: + save, use and delete named queries + + Scenario: save, use and delete named queries + When we 
connect to test database + then we see database connected + when we save a named query + then we see the named query saved + when we use a named query + then we see the named query executed + when we delete a named query + then we see the named query deleted + + Scenario: save, use and delete named queries with parameters + When we connect to test database + then we see database connected + when we save a named query with parameters + then we see the named query saved + when we use named query with parameters + then we see the named query with parameters executed + when we use named query with too few parameters + then we see the named query with parameters fail with missing parameters + when we use named query with too many parameters + then we see the named query with parameters fail with extra parameters diff --git a/tests/features/specials.feature b/test/features/specials.feature similarity index 62% rename from tests/features/specials.feature rename to test/features/specials.feature index 9bacec45..bb367578 100644 --- a/tests/features/specials.feature +++ b/test/features/specials.feature @@ -2,8 +2,6 @@ Feature: Special commands @wip Scenario: run refresh command - When we run dbcli - and we wait for prompt - and we refresh completions + When we refresh completions and we wait for prompt then we see completions refresh started diff --git a/test/features/steps/__init__.py b/test/features/steps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py new file mode 100644 index 00000000..e1cb26f8 --- /dev/null +++ b/test/features/steps/auto_vertical.py @@ -0,0 +1,46 @@ +from textwrap import dedent + +from behave import then, when + +import wrappers +from utils import parse_cli_args_to_dict + + +@when('we run dbcli with {arg}') +def step_run_cli_with_arg(context, arg): + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) + + +@when('we execute a small query') 
+def step_execute_small_query(context): + context.cli.sendline('select 1') + + +@when('we execute a large query') +def step_execute_large_query(context): + context.cli.sendline( + 'select {}'.format(','.join([str(n) for n in range(1, 50)]))) + + +@then('we see small results in horizontal format') +def step_see_small_results(context): + wrappers.expect_pager(context, dedent("""\ + +---+\r + | 1 |\r + +---+\r + | 1 |\r + +---+\r + \r + """), timeout=5) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then('we see large results in vertical format') +def step_see_large_results(context): + rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)] + expected = ('***************************[ 1. row ]' + '***************************\r\n' + + '{}\r\n'.format('\r\n'.join(rows) + '\r\n')) + + wrappers.expect_pager(context, expected, timeout=10) + wrappers.expect_exact(context, '1 row in set', timeout=2) diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py new file mode 100644 index 00000000..425ef674 --- /dev/null +++ b/test/features/steps/basic_commands.py @@ -0,0 +1,100 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +from behave import when +from textwrap import dedent +import tempfile +import wrappers + + +@when('we run dbcli') +def step_run_cli(context): + wrappers.run_cli(context) + + +@when('we wait for prompt') +def step_wait_prompt(context): + wrappers.wait_prompt(context) + + +@when('we send "ctrl + d"') +def step_ctrl_d(context): + """Send Ctrl + D to hopefully exit.""" + context.cli.sendcontrol('d') + context.exit_sent = True + + +@when('we send "\?" command') +def step_send_help(context): + """Send \? + + to see help. 
+ + """ + context.cli.sendline('\\?') + wrappers.expect_exact( + context, context.conf['pager_boundary'] + '\r\n', timeout=5) + + +@when(u'we send source command') +def step_send_source_command(context): + with tempfile.NamedTemporaryFile() as f: + f.write(b'\?') + f.flush() + context.cli.sendline('\. {0}'.format(f.name)) + wrappers.expect_exact( + context, context.conf['pager_boundary'] + '\r\n', timeout=5) + + +@when(u'we run query to check application_name') +def step_check_application_name(context): + context.cli.sendline( + "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'" + ) + + +@then(u'we see found') +def step_see_found(context): + wrappers.expect_exact( + context, + context.conf['pager_boundary'] + '\r' + dedent(''' + +-------+\r + | found |\r + +-------+\r + | found |\r + +-------+\r + \r + ''') + context.conf['pager_boundary'], + timeout=5 + ) + + +@then(u'we confirm the destructive warning') +def step_confirm_destructive_command(context): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline('y') + + +@when(u'we answer the destructive warning with "{confirmation}"') +def step_confirm_destructive_command(context, confirmation): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline(confirmation) + + +@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') +def step_confirm_destructive_command(context, confirmation, text): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? 
(y/n):', timeout=2) + context.cli.sendline(confirmation) + wrappers.expect_exact(context, text, timeout=2) + # we must exit the Click loop, or the feature will hang + context.cli.sendline('n') diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py new file mode 100644 index 00000000..e16dd867 --- /dev/null +++ b/test/features/steps/connection.py @@ -0,0 +1,71 @@ +import io +import os +import shlex + +from behave import when, then +import pexpect + +import wrappers +from test.features.steps.utils import parse_cli_args_to_dict +from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context +from test.utils import HOST, PORT, USER, PASSWORD +from mycli.config import encrypt_mylogin_cnf + + +TEST_LOGIN_PATH = 'test_login_path' + + +@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') +@when('we run mycli without arguments "{excluded_args}"') +def step_run_cli_without_args(context, excluded_args, exact_args=''): + wrappers.run_cli( + context, + run_args=parse_cli_args_to_dict(exact_args), + exclude_args=parse_cli_args_to_dict(excluded_args).keys() + ) + + +@then('status contains "{expression}"') +def status_contains(context, expression): + wrappers.expect_exact(context, f'{expression}', timeout=5) + + # Normally, the shutdown after scenario waits for the prompt. 
+ # But we may have changed the prompt, depending on parameters, + # so let's wait for its last character + context.cli.expect_exact('>') + context.atprompt = True + + +@when('we create my.cnf file') +def step_create_my_cnf_file(context): + my_cnf = ( + '[client]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MY_CNF_PATH, 'w') as f: + f.write(my_cnf) + + +@when('we create mylogin.cnf file') +def step_create_mylogin_cnf_file(context): + os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) + mylogin_cnf = ( + f'[{TEST_LOGIN_PATH}]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MYLOGIN_CNF_PATH, 'wb') as f: + input_file = io.StringIO(mylogin_cnf) + f.write(encrypt_mylogin_cnf(input_file).read()) + + +@then('we are logged in') +def we_are_logged_in(context): + db_name = get_db_name_from_context(context) + context.cli.expect_exact(f'{db_name}>', timeout=5) + context.atprompt = True diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py new file mode 100644 index 00000000..841f37d0 --- /dev/null +++ b/test/features/steps/crud_database.py @@ -0,0 +1,115 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. 
+ +""" + +import pexpect + +import wrappers +from behave import when, then + + +@when('we create database') +def step_db_create(context): + """Send create database.""" + context.cli.sendline('create database {0};'.format( + context.conf['dbname_tmp'])) + + context.response = { + 'database_name': context.conf['dbname_tmp'] + } + + +@when('we drop database') +def step_db_drop(context): + """Send drop database.""" + context.cli.sendline('drop database {0};'.format( + context.conf['dbname_tmp'])) + + +@when('we connect to test database') +def step_db_connect_test(context): + """Send connect to database.""" + db_name = context.conf['dbname'] + context.currentdb = db_name + context.cli.sendline('use {0};'.format(db_name)) + + +@when('we connect to quoted test database') +def step_db_connect_quoted_tmp(context): + """Send connect to database.""" + db_name = context.conf['dbname'] + context.currentdb = db_name + context.cli.sendline('use `{0}`;'.format(db_name)) + + +@when('we connect to tmp database') +def step_db_connect_tmp(context): + """Send connect to database.""" + db_name = context.conf['dbname_tmp'] + context.currentdb = db_name + context.cli.sendline('use {0}'.format(db_name)) + + +@when('we connect to dbserver') +def step_db_connect_dbserver(context): + """Send connect to database.""" + context.currentdb = 'mysql' + context.cli.sendline('use mysql') + + +@then('dbcli exits') +def step_wait_exit(context): + """Make sure the cli exits.""" + wrappers.expect_exact(context, pexpect.EOF, timeout=5) + + +@then('we see dbcli prompt') +def step_see_prompt(context): + """Wait to see the prompt.""" + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname)) + + +@then('we see help output') +def step_see_help(context): + for expected_line in context.fixture_data['help_commands.txt']: + wrappers.expect_exact(context, expected_line, timeout=1) + + +@then('we see database 
created') +def step_see_db_created(context): + """Wait to see create database output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see database dropped') +def step_see_db_dropped(context): + """Wait to see drop database output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@then('we see database dropped and no default database') +def step_see_db_dropped_no_default(context): + """Wait to see drop database output.""" + user = context.conf['user'] + host = context.conf['host'] + database = '(none)' + context.currentdb = None + + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database)) + + +@then('we see database connected') +def step_see_db_connected(context): + """Wait to see drop database output.""" + wrappers.expect_exact( + context, 'You are now connected to database "', timeout=2) + wrappers.expect_exact(context, '"', timeout=2) + wrappers.expect_exact(context, ' as user "{0}"'.format( + context.conf['user']), timeout=2) diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py new file mode 100644 index 00000000..f715f0ca --- /dev/null +++ b/test/features/steps/crud_table.py @@ -0,0 +1,112 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. 
+ +""" + +import wrappers +from behave import when, then +from textwrap import dedent + + +@when('we create table') +def step_create_table(context): + """Send create table.""" + context.cli.sendline('create table a(x text);') + + +@when('we insert into table') +def step_insert_into_table(context): + """Send insert into table.""" + context.cli.sendline('''insert into a(x) values('xxx');''') + + +@when('we update table') +def step_update_table(context): + """Send insert into table.""" + context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''') + + +@when('we select from table') +def step_select_from_table(context): + """Send select from table.""" + context.cli.sendline('select * from a;') + + +@when('we delete from table') +def step_delete_from_table(context): + """Send deete from table.""" + context.cli.sendline('''delete from a where x = 'yyy';''') + + +@when('we drop table') +def step_drop_table(context): + """Send drop table.""" + context.cli.sendline('drop table a;') + + +@then('we see table created') +def step_see_table_created(context): + """Wait to see create table output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@then('we see record inserted') +def step_see_record_inserted(context): + """Wait to see insert output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see record updated') +def step_see_record_updated(context): + """Wait to see update output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see data selected') +def step_see_data_selected(context): + """Wait to see select output.""" + wrappers.expect_pager( + context, dedent("""\ + +-----+\r + | x |\r + +-----+\r + | yyy |\r + +-----+\r + \r + """), timeout=2) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then('we see record deleted') +def step_see_data_deleted(context): + """Wait to see delete output.""" + wrappers.expect_exact(context, 'Query OK, 1 
row affected', timeout=2) + + +@then('we see table dropped') +def step_see_table_dropped(context): + """Wait to see drop output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@when('we select null') +def step_select_null(context): + """Send select null.""" + context.cli.sendline('select null;') + + +@then('we see null selected') +def step_see_null_selected(context): + """Wait to see null output.""" + wrappers.expect_pager( + context, dedent("""\ + +--------+\r + | NULL |\r + +--------+\r + | |\r + +--------+\r + \r + """), timeout=2) + wrappers.expect_exact(context, '1 row in set', timeout=2) diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py new file mode 100644 index 00000000..bbabf431 --- /dev/null +++ b/test/features/steps/iocommands.py @@ -0,0 +1,105 @@ +import os +import wrappers + +from behave import when, then +from textwrap import dedent + + +@when('we start external editor providing a file name') +def step_edit_file(context): + """Edit file with external editor.""" + context.editor_file_name = os.path.join( + context.package_root, 'test_file_{0}.sql'.format(context.conf['vi'])) + if os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.cli.sendline('\e {0}'.format( + os.path.basename(context.editor_file_name))) + wrappers.expect_exact( + context, 'Entering Ex mode. 
Type "visual" to go to Normal mode.', timeout=2) + wrappers.expect_exact(context, '\r\n:', timeout=2) + + +@when('we type "{query}" in the editor') +def step_edit_type_sql(context, query): + context.cli.sendline('i') + context.cli.sendline(query) + context.cli.sendline('.') + wrappers.expect_exact(context, '\r\n:', timeout=2) + + +@when('we exit the editor') +def step_edit_quit(context): + context.cli.sendline('x') + wrappers.expect_exact(context, "written", timeout=2) + + +@then('we see "{query}" in prompt') +def step_edit_done_sql(context, query): + for match in query.split(' '): + wrappers.expect_exact(context, match, timeout=5) + # Cleanup the command line. + context.cli.sendcontrol('c') + # Cleanup the edited file. + if context.editor_file_name and os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + + +@when(u'we tee output') +def step_tee_ouptut(context): + context.tee_file_name = os.path.join( + context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi'])) + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.cli.sendline('tee {0}'.format( + os.path.basename(context.tee_file_name))) + + +@when(u'we select "select {param}"') +def step_query_select_number(context, param): + context.cli.sendline(u'select {}'.format(param)) + wrappers.expect_pager(context, dedent(u"""\ + +{dashes}+\r + | {param} |\r + +{dashes}+\r + | {param} |\r + +{dashes}+\r + \r + """.format(param=param, dashes='-' * (len(param) + 2)) + ), timeout=5) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then(u'we see result "{result}"') +def step_see_result(context, result): + wrappers.expect_exact( + context, + u"| {} |".format(result), + timeout=2 + ) + + +@when(u'we query "{query}"') +def step_query(context, query): + context.cli.sendline(query) + + +@when(u'we notee output') +def step_notee_output(context): + context.cli.sendline('notee') + + +@then(u'we see 123456 in tee output') +def 
step_see_123456_in_ouput(context): + with open(context.tee_file_name) as f: + assert '123456' in f.read() + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + + +@then(u'delimiter is set to "{delimiter}"') +def delimiter_is_set(context, delimiter): + wrappers.expect_exact( + context, + u'Changed delimiter to {}'.format(delimiter), + timeout=2 + ) diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py new file mode 100644 index 00000000..bc1f8663 --- /dev/null +++ b/test/features/steps/named_queries.py @@ -0,0 +1,90 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import wrappers +from behave import when, then + + +@when('we save a named query') +def step_save_named_query(context): + """Send \fs command.""" + context.cli.sendline('\\fs foo SELECT 12345') + + +@when('we use a named query') +def step_use_named_query(context): + """Send \f command.""" + context.cli.sendline('\\f foo') + + +@when('we delete a named query') +def step_delete_named_query(context): + """Send \fd command.""" + context.cli.sendline('\\fd foo') + + +@then('we see the named query saved') +def step_see_named_query_saved(context): + """Wait to see query saved.""" + wrappers.expect_exact(context, 'Saved.', timeout=2) + + +@then('we see the named query executed') +def step_see_named_query_executed(context): + """Wait to see select output.""" + wrappers.expect_exact(context, 'SELECT 12345', timeout=2) + + +@then('we see the named query deleted') +def step_see_named_query_deleted(context): + """Wait to see query deleted.""" + wrappers.expect_exact(context, 'foo: Deleted', timeout=2) + + +@when('we save a named query with parameters') +def step_save_named_query_with_parameters(context): + """Send \fs command for query with parameters.""" + context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"') + + 
+@when('we use named query with parameters') +def step_use_named_query_with_parameters(context): + """Send \f command with parameters.""" + context.cli.sendline('\\f foo_args 101 second "third value"') + + +@then('we see the named query with parameters executed') +def step_see_named_query_with_parameters_executed(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'SELECT 101, "second", "third value"', timeout=2) + + +@when('we use named query with too few parameters') +def step_use_named_query_with_too_few_parameters(context): + """Send \f command with missing parameters.""" + context.cli.sendline('\\f foo_args 101') + + +@then('we see the named query with parameters fail with missing parameters') +def step_see_named_query_with_parameters_fail_with_missing_parameters(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'missing substitution for $2 in query:', timeout=2) + + +@when('we use named query with too many parameters') +def step_use_named_query_with_too_many_parameters(context): + """Send \f command with extra parameters.""" + context.cli.sendline('\\f foo_args 101 102 103 104') + + +@then('we see the named query with parameters fail with extra parameters') +def step_see_named_query_with_parameters_fail_with_extra_parameters(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'query does not have substitution parameter $4:', timeout=2) diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py new file mode 100644 index 00000000..e8b99e3e --- /dev/null +++ b/test/features/steps/specials.py @@ -0,0 +1,27 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. 
+ +""" + +import wrappers +from behave import when, then + + +@when('we refresh completions') +def step_refresh_completions(context): + """Send refresh command.""" + context.cli.sendline('rehash') + + +@then('we see text "{text}"') +def step_see_text(context, text): + """Wait to see given text message.""" + wrappers.expect_exact(context, text, timeout=2) + +@then('we see completions refresh started') +def step_see_refresh_started(context): + """Wait to see refresh output.""" + wrappers.expect_exact( + context, 'Auto-completion refresh started in the background.', timeout=2) diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py new file mode 100644 index 00000000..1ae63d2b --- /dev/null +++ b/test/features/steps/utils.py @@ -0,0 +1,12 @@ +import shlex + + +def parse_cli_args_to_dict(cli_args: str): + args_dict = {} + for arg in shlex.split(cli_args): + if '=' in arg: + key, value = arg.split('=') + args_dict[key] = value + else: + args_dict[arg] = None + return args_dict diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py new file mode 100644 index 00000000..6408f235 --- /dev/null +++ b/test/features/steps/wrappers.py @@ -0,0 +1,117 @@ +import re +import pexpect +import sys +import textwrap + + +try: + from StringIO import StringIO +except ImportError: + from io import StringIO + + +def expect_exact(context, expected, timeout): + timedout = False + try: + context.cli.expect_exact(expected, timeout=timeout) + except pexpect.TIMEOUT: + timedout = True + if timedout: + # Strip color codes out of the output. 
+ actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?', + '', context.cli.before) + raise Exception( + textwrap.dedent('''\ + Expected: + --- + {0!r} + --- + Actual: + --- + {1!r} + --- + Full log: + --- + {2!r} + --- + ''').format( + expected, + actual, + context.logfile.getvalue() + ) + ) + + +def expect_pager(context, expected, timeout): + expect_exact(context, "{0}\r\n{1}{0}\r\n".format( + context.conf['pager_boundary'], expected), timeout=timeout) + + +def run_cli(context, run_args=None, exclude_args=None): + """Run the process using pexpect.""" + run_args = run_args or {} + rendered_args = [] + exclude_args = set(exclude_args) if exclude_args else set() + + conf = dict(**context.conf) + conf.update(run_args) + + def add_arg(name, key, value): + if name not in exclude_args: + if value is not None: + rendered_args.extend((key, value)) + else: + rendered_args.append(key) + + if conf.get('host', None): + add_arg('host', '-h', conf['host']) + if conf.get('user', None): + add_arg('user', '-u', conf['user']) + if conf.get('pass', None): + add_arg('pass', '-p', conf['pass']) + if conf.get('port', None): + add_arg('port', '-P', str(conf['port'])) + if conf.get('dbname', None): + add_arg('dbname', '-D', conf['dbname']) + if conf.get('defaults-file', None): + add_arg('defaults_file', '--defaults-file', conf['defaults-file']) + if conf.get('myclirc', None): + add_arg('myclirc', '--myclirc', conf['myclirc']) + if conf.get('login_path'): + add_arg('login_path', '--login-path', conf['login_path']) + + for arg_name, arg_value in conf.items(): + if arg_name.startswith('-'): + add_arg(arg_name, arg_name, arg_value) + + try: + cli_cmd = context.conf['cli_command'] + except KeyError: + cli_cmd = ( + '{0!s} -c "' + 'import coverage ; ' + 'coverage.process_startup(); ' + 'import mycli.main; ' + 'mycli.main.cli()' + '"' + ).format(sys.executable) + + cmd_parts = [cli_cmd] + rendered_args + cmd = ' '.join(cmd_parts) + context.cli = pexpect.spawnu(cmd, cwd=context.package_root) + 
context.logfile = StringIO() + context.cli.logfile = context.logfile + context.exit_sent = False + context.currentdb = context.conf['dbname'] + + +def wait_prompt(context, prompt=None): + """Make sure prompt is displayed.""" + if prompt is None: + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + prompt = '{0}@{1}:{2}>'.format( + user, host, dbname), + expect_exact(context, prompt, timeout=5) + context.atprompt = True diff --git a/test/features/wrappager.py b/test/features/wrappager.py new file mode 100755 index 00000000..51d49095 --- /dev/null +++ b/test/features/wrappager.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +import sys + + +def wrappager(boundary): + print(boundary) + while 1: + buf = sys.stdin.read(2048) + if not buf: + break + sys.stdout.write(buf) + print(boundary) + + +if __name__ == "__main__": + wrappager(sys.argv[1]) diff --git a/test/myclirc b/test/myclirc new file mode 100644 index 00000000..261bee6e --- /dev/null +++ b/test/myclirc @@ -0,0 +1,12 @@ +# vi: ft=dosini + +# This file is loaded after mycli/myclirc and should override only those +# variables needed for testing. 
+# To see what every variable does see mycli/myclirc + +[main] + +log_file = ~/.mycli.test.log +log_level = DEBUG +prompt = '\t \u@\h:\d> ' +less_chatty = True diff --git a/tests/mylogin.cnf b/test/mylogin.cnf similarity index 100% rename from tests/mylogin.cnf rename to test/mylogin.cnf diff --git a/tests/test.txt b/test/test.txt similarity index 100% rename from tests/test.txt rename to test/test.txt diff --git a/test/test_clistyle.py b/test/test_clistyle.py new file mode 100644 index 00000000..f82cdf0e --- /dev/null +++ b/test/test_clistyle.py @@ -0,0 +1,27 @@ +"""Test the mycli.clistyle module.""" +import pytest + +from pygments.style import Style +from pygments.token import Token + +from mycli.clistyle import style_factory + + +@pytest.mark.skip(reason="incompatible with new prompt toolkit") +def test_style_factory(): + """Test that a Pygments Style class is created.""" + header = 'bold underline #ansired' + cli_style = {'Token.Output.Header': header} + style = style_factory('default', cli_style) + + assert isinstance(style(), Style) + assert Token.Output.Header in style.styles + assert header == style.styles[Token.Output.Header] + + +@pytest.mark.skip(reason="incompatible with new prompt toolkit") +def test_style_factory_unknown_name(): + """Test that an unrecognized name will not throw an error.""" + style = style_factory('foobar', {}) + + assert isinstance(style(), Style) diff --git a/tests/test_completion_engine.py b/test/test_completion_engine.py similarity index 77% rename from tests/test_completion_engine.py rename to test/test_completion_engine.py index 9b877108..8b06ed38 100644 --- a/tests/test_completion_engine.py +++ b/test/test_completion_engine.py @@ -1,25 +1,30 @@ from mycli.packages.completion_engine import suggest_type import pytest + def sorted_dicts(dicts): - """input is a list of dicts""" + """input is a list of dicts.""" return sorted(tuple(x.items()) for x in dicts) + def test_select_suggests_cols_with_visible_table_scope(): suggestions = 
suggest_type('SELECT FROM tabl', 'SELECT ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + def test_select_suggests_cols_with_qualified_table_scope(): suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [('sch', 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [('sch', 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) @pytest.mark.parametrize('expression', [ @@ -37,10 +42,12 @@ def test_select_suggests_cols_with_qualified_table_scope(): def test_where_suggests_columns_functions(expression): suggestions = suggest_type(expression, expression) assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + @pytest.mark.parametrize('expression', [ 'SELECT * FROM tabl WHERE foo IN (', @@ -49,41 +56,52 @@ def test_where_suggests_columns_functions(expression): def test_where_in_suggests_columns(expression): suggestions = suggest_type(expression, expression) assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) 
+ def test_where_equals_any_suggests_columns_or_keywords(): text = 'SELECT * FROM tabl WHERE foo = ANY(' suggestions = suggest_type(text, text) assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}]) + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}]) + def test_lparen_suggests_cols(): suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(') assert suggestion == [ {'type': 'column', 'tables': [(None, 'tbl', None)]}] + def test_operand_inside_function_suggests_cols1(): - suggestion = suggest_type('SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ') + suggestion = suggest_type( + 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ') assert suggestion == [ {'type': 'column', 'tables': [(None, 'tbl', None)]}] + def test_operand_inside_function_suggests_cols2(): - suggestion = suggest_type('SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ') + suggestion = suggest_type( + 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ') assert suggestion == [ {'type': 'column', 'tables': [(None, 'tbl', None)]}] + def test_select_suggests_cols_and_funcs(): suggestions = suggest_type('SELECT ', 'SELECT ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': []}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + {'type': 'alias', 'aliases': []}, + {'type': 'column', 'tables': []}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + @pytest.mark.parametrize('expression', [ 'SELECT * FROM ', @@ -102,6 +120,7 @@ def test_expression_suggests_tables_views_and_schemas(expression): {'type': 'view', 'schema': []}, {'type': 'schema'}]) + @pytest.mark.parametrize('expression', [ 'SELECT * FROM sch.', 'INSERT INTO sch.', @@ -118,37 +137,44 @@ def 
test_expression_suggests_qualified_tables_views_and_schemas(expression): {'type': 'table', 'schema': 'sch'}, {'type': 'view', 'schema': 'sch'}]) + def test_truncate_suggests_tables_and_schemas(): suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ') assert sorted_dicts(suggestions) == sorted_dicts([ {'type': 'table', 'schema': []}, {'type': 'schema'}]) + def test_truncate_suggests_qualified_tables(): suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.') assert sorted_dicts(suggestions) == sorted_dicts([ {'type': 'table', 'schema': 'sch'}]) + def test_distinct_suggests_cols(): suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ') assert suggestions == [{'type': 'column', 'tables': []}] + def test_col_comma_suggests_cols(): suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,') assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tbl']}, {'type': 'column', 'tables': [(None, 'tbl', None)]}, {'type': 'function', 'schema': []}, {'type': 'keyword'}, - ]) + ]) + def test_table_comma_suggests_tables_and_schemas(): suggestions = suggest_type('SELECT a, b FROM tbl1, ', - 'SELECT a, b FROM tbl1, ') + 'SELECT a, b FROM tbl1, ') assert sorted_dicts(suggestions) == sorted_dicts([ {'type': 'table', 'schema': []}, {'type': 'view', 'schema': []}, {'type': 'schema'}]) + def test_into_suggests_tables_and_schemas(): suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ') assert sorted_dicts(suggestion) == sorted_dicts([ @@ -156,26 +182,32 @@ def test_into_suggests_tables_and_schemas(): {'type': 'view', 'schema': []}, {'type': 'schema'}]) + def test_insert_into_lparen_suggests_cols(): suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (') assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + def test_insert_into_lparen_partial_text_suggests_cols(): suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i') assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', 
None)]}] + def test_insert_into_lparen_comma_suggests_cols(): suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,') assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + def test_partially_typed_col_name_suggests_col_names(): suggestions = suggest_type('SELECT * FROM tabl WHERE col_n', - 'SELECT * FROM tabl WHERE col_n') + 'SELECT * FROM tabl WHERE col_n') assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, {'type': 'column', 'tables': [(None, 'tabl', None)]}, {'type': 'function', 'schema': []}, {'type': 'keyword'}, - ]) + ]) + def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.') @@ -185,33 +217,38 @@ def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): {'type': 'view', 'schema': 'tabl'}, {'type': 'function', 'schema': 'tabl'}]) + def test_dot_suggests_cols_of_an_alias(): suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2', - 'SELECT t1.') + 'SELECT t1.') assert sorted_dicts(suggestions) == sorted_dicts([ {'type': 'table', 'schema': 't1'}, {'type': 'view', 'schema': 't1'}, {'type': 'column', 'tables': [(None, 'tabl1', 't1')]}, {'type': 'function', 'schema': 't1'}]) + def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): suggestions = suggest_type('SELECT t1.a, t2. 
FROM tabl1 t1, tabl2 t2', - 'SELECT t1.a, t2.') + 'SELECT t1.a, t2.') assert sorted_dicts(suggestions) == sorted_dicts([ {'type': 'column', 'tables': [(None, 'tabl2', 't2')]}, {'type': 'table', 'schema': 't2'}, {'type': 'view', 'schema': 't2'}, {'type': 'function', 'schema': 't2'}]) + @pytest.mark.parametrize('expression', [ 'SELECT * FROM (', 'SELECT * FROM foo WHERE EXISTS (', 'SELECT * FROM foo WHERE bar AND NOT EXISTS (', + 'SELECT 1 AS', ]) def test_sub_select_suggests_keyword(expression): suggestion = suggest_type(expression, expression) assert suggestion == [{'type': 'keyword'}] + @pytest.mark.parametrize('expression', [ 'SELECT * FROM (S', 'SELECT * FROM foo WHERE EXISTS (S', @@ -221,6 +258,7 @@ def test_sub_select_partial_text_suggests_keyword(expression): suggestion = suggest_type(expression, expression) assert suggestion == [{'type': 'keyword'}] + def test_outer_table_reference_in_exists_subquery_suggests_columns(): q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.' 
suggestions = suggest_type(q, q) @@ -230,6 +268,7 @@ def test_outer_table_reference_in_exists_subquery_suggests_columns(): {'type': 'view', 'schema': 'f'}, {'type': 'function', 'schema': 'f'}] + @pytest.mark.parametrize('expression', [ 'SELECT * FROM (SELECT * FROM ', 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ', @@ -242,32 +281,37 @@ def test_sub_select_table_name_completion(expression): {'type': 'view', 'schema': []}, {'type': 'schema'}]) + def test_sub_select_col_name_completion(): suggestions = suggest_type('SELECT * FROM (SELECT FROM abc', - 'SELECT * FROM (SELECT ') + 'SELECT * FROM (SELECT ') assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['abc']}, {'type': 'column', 'tables': [(None, 'abc', None)]}, {'type': 'function', 'schema': []}, {'type': 'keyword'}, - ]) + ]) + @pytest.mark.xfail def test_sub_select_multiple_col_name_completion(): suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc', - 'SELECT * FROM (SELECT a, ') + 'SELECT * FROM (SELECT a, ') assert sorted_dicts(suggestions) == sorted_dicts([ {'type': 'column', 'tables': [(None, 'abc', None)]}, {'type': 'function', 'schema': []}]) + def test_sub_select_dot_col_name_completion(): suggestions = suggest_type('SELECT * FROM (SELECT t. 
FROM tabl t', - 'SELECT * FROM (SELECT t.') + 'SELECT * FROM (SELECT t.') assert sorted_dicts(suggestions) == sorted_dicts([ {'type': 'column', 'tables': [(None, 'tabl', 't')]}, {'type': 'table', 'schema': 't'}, {'type': 'view', 'schema': 't'}, {'type': 'function', 'schema': 't'}]) + @pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER']) @pytest.mark.parametrize('tbl_alias', ['', 'foo']) def test_join_suggests_tables_and_schemas(tbl_alias, join_type): @@ -278,6 +322,7 @@ def test_join_suggests_tables_and_schemas(tbl_alias, join_type): {'type': 'view', 'schema': []}, {'type': 'schema'}]) + @pytest.mark.parametrize('sql', [ 'SELECT * FROM abc a JOIN def d ON a.', 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.', @@ -290,6 +335,7 @@ def test_join_alias_dot_suggests_cols1(sql): {'type': 'view', 'schema': 'a'}, {'type': 'function', 'schema': 'a'}]) + @pytest.mark.parametrize('sql', [ 'SELECT * FROM abc a JOIN def d ON a.id = d.', 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.', @@ -302,6 +348,7 @@ def test_join_alias_dot_suggests_cols2(sql): {'type': 'view', 'schema': 'd'}, {'type': 'function', 'schema': 'd'}]) + @pytest.mark.parametrize('sql', [ 'select a.x, b.y from abc a join bcd b on ', 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ', @@ -310,6 +357,7 @@ def test_on_suggests_aliases(sql): suggestions = suggest_type(sql, sql) assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + @pytest.mark.parametrize('sql', [ 'select abc.x, bcd.y from abc join bcd on ', 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ', @@ -318,6 +366,7 @@ def test_on_suggests_tables(sql): suggestions = suggest_type(sql, sql) assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + @pytest.mark.parametrize('sql', [ 'select a.x, b.y from abc a join bcd b on a.id = ', 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ', @@ -326,6 +375,7 @@ def test_on_suggests_aliases_right_side(sql): 
suggestions = suggest_type(sql, sql) assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + @pytest.mark.parametrize('sql', [ 'select abc.x, bcd.y from abc join bcd on ', 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ', @@ -343,62 +393,78 @@ def test_join_using_suggests_common_columns(col_list): 'tables': [(None, 'abc', None), (None, 'def', None)], 'drop_unique': True}] +@pytest.mark.parametrize('sql', [ + 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.', + 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.', +]) +def test_two_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'ghi', 'g')]}, + {'type': 'table', 'schema': 'g'}, + {'type': 'view', 'schema': 'g'}, + {'type': 'function', 'schema': 'g'}]) def test_2_statements_2nd_current(): suggestions = suggest_type('select * from a; select * from ', 'select * from a; select * from ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) suggestions = suggest_type('select * from a; select from b', 'select * from a; select ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'b', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + {'type': 'alias', 'aliases': ['b']}, + {'type': 'column', 'tables': [(None, 'b', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) # Should work even if first statement is invalid suggestions = suggest_type('select * from; select * from ', 'select * from; select * from ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + {'type': 
'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + def test_2_statements_1st_current(): suggestions = suggest_type('select * from ; select * from b', 'select * from ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) suggestions = suggest_type('select from a; select * from b', 'select ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'a', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + {'type': 'alias', 'aliases': ['a']}, + {'type': 'column', 'tables': [(None, 'a', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + def test_3_statements_2nd_current(): suggestions = suggest_type('select * from a; select * from ; select * from c', 'select * from a; select * from ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) suggestions = suggest_type('select * from a; select from b; select * from c', 'select * from a; select ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'b', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + {'type': 'alias', 'aliases': ['b']}, + {'type': 'column', 'tables': [(None, 'b', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) def test_create_db_with_template(): @@ -435,13 +501,15 @@ def test_handle_pre_completion_comma_gracefully(text): assert iter(suggestions) + def test_cross_join(): text = 'select * from v1 cross join v2 JOIN v1.id, ' suggestions = suggest_type(text, text) assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - 
{'type': 'view', 'schema': []}, - {'type': 'schema'}]) + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + @pytest.mark.parametrize('expression', [ 'SELECT 1 AS ', @@ -450,3 +518,31 @@ def test_cross_join(): def test_after_as(expression): suggestions = suggest_type(expression, expression) assert set(suggestions) == set() + + +@pytest.mark.parametrize('expression', [ + '\\. ', + 'select 1; \\. ', + 'select 1;\\. ', + 'select 1 ; \\. ', + 'source ', + 'truncate table test; source ', + 'truncate table test ; source ', + 'truncate table test;source ', +]) +def test_source_is_file(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{'type': 'file_name'}] + + +@pytest.mark.parametrize("expression", [ + "\\f ", +]) +def test_favorite_name_suggestion(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{'type': 'favoritequery'}] + +def test_order_by(): + text = 'select * from foo order by ' + suggestions = suggest_type(text, text) + assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}] diff --git a/tests/test_completion_refresher.py b/test/test_completion_refresher.py similarity index 91% rename from tests/test_completion_refresher.py rename to test/test_completion_refresher.py index 8851eae6..cdc2fb5e 100644 --- a/tests/test_completion_refresher.py +++ b/test/test_completion_refresher.py @@ -1,6 +1,6 @@ import time import pytest -from mock import Mock, patch +from unittest.mock import Mock, patch @pytest.fixture @@ -10,10 +10,11 @@ def refresher(): def test_ctor(refresher): - """ - Refresher object should contain a few handlers + """Refresher object should contain a few handlers. 
+ :param refresher: :return: + """ assert len(refresher.refreshers) > 0 actual_handlers = list(refresher.refreshers.keys()) @@ -41,10 +42,11 @@ def test_refresh_called_once(refresher): def test_refresh_called_twice(refresher): - """ - If refresh is called a second time, it should be restarted + """If refresh is called a second time, it should be restarted. + :param refresher: :return: + """ callbacks = Mock() @@ -69,9 +71,10 @@ def dummy_bg_refresh(*args): def test_refresh_with_callbacks(refresher): - """ - Callbacks must be called + """Callbacks must be called. + :param refresher: + """ callbacks = [Mock()] sqlexecute_class = Mock() diff --git a/tests/test_config.py b/test/test_config.py similarity index 66% rename from tests/test_config.py rename to test/test_config.py index 2a0d26c1..7f2b2442 100644 --- a/tests/test_config.py +++ b/test/test_config.py @@ -1,17 +1,14 @@ """Unit tests for the mycli.config module.""" -from io import BytesIO, TextIOWrapper +from io import BytesIO, StringIO, TextIOWrapper import os -import pip import struct import sys import tempfile import pytest from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf, - read_and_decrypt_mylogin_cnf, str_to_bool) - -with_pycryptodome = ['pycryptodome' in set([package.project_name for package in - pip.get_installed_distributions()])] + read_and_decrypt_mylogin_cnf, read_config_file, + str_to_bool, strip_matching_quotes) LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), 'mylogin.cnf')) @@ -24,8 +21,6 @@ def open_bmylogin_cnf(name): buf.write(f.read()) return buf - -@pytest.mark.skipif(not with_pycryptodome, reason='requires pycryptodome') def test_read_mylogin_cnf(): """Tests that a login path file can be read and decrypted.""" mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE) @@ -37,14 +32,12 @@ def test_read_mylogin_cnf(): assert word in contents -@pytest.mark.skipif(not with_pycryptodome, reason='requires pycryptodome') def test_decrypt_blank_mylogin_cnf(): """Test 
that a blank login path file is handled correctly.""" mylogin_cnf = read_and_decrypt_mylogin_cnf(BytesIO()) assert mylogin_cnf is None -@pytest.mark.skipif(not with_pycryptodome, reason='requires pycryptodome') def test_corrupted_login_key(): """Test that a corrupted login path key is handled correctly.""" buf = open_bmylogin_cnf(LOGIN_PATH_FILE) @@ -61,7 +54,6 @@ def test_corrupted_login_key(): assert mylogin_cnf is None -@pytest.mark.skipif(not with_pycryptodome, reason='requires pycryptodome') def test_corrupted_pad(): """Tests that a login path file with a corrupted pad is partially read.""" buf = open_bmylogin_cnf(LOGIN_PATH_FILE) @@ -146,3 +138,59 @@ def test_str_to_bool(): with pytest.raises(TypeError): str_to_bool(None) + + +def test_read_config_file_list_values_default(): + """Test that reading a config file uses list_values by default.""" + + f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + config = read_config_file(f) + + assert config['main']['weather'] == u"cloudy with a chance of meatballs" + + +def test_read_config_file_list_values_off(): + """Test that you can disable list_values when reading a config file.""" + + f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + config = read_config_file(f, list_values=False) + + assert config['main']['weather'] == u"'cloudy with a chance of meatballs'" + + +def test_strip_quotes_with_matching_quotes(): + """Test that a string with matching quotes is unquoted.""" + + s = "May the force be with you." + assert s == strip_matching_quotes('"{}"'.format(s)) + assert s == strip_matching_quotes("'{}'".format(s)) + + +def test_strip_quotes_with_unmatching_quotes(): + """Test that a string with unmatching quotes is not unquoted.""" + + s = "May the force be with you." 
+ assert '"' + s == strip_matching_quotes('"{}'.format(s)) + assert s + "'" == strip_matching_quotes("{}'".format(s)) + + +def test_strip_quotes_with_empty_string(): + """Test that an empty string is handled during unquoting.""" + + assert '' == strip_matching_quotes('') + + +def test_strip_quotes_with_none(): + """Test that None is handled during unquoting.""" + + assert None is strip_matching_quotes(None) + + +def test_strip_quotes_with_quotes(): + """Test that strings with quotes in them are handled during unquoting.""" + + s1 = 'Darth Vader said, "Luke, I am your father."' + assert s1 == strip_matching_quotes(s1) + + s2 = '"Darth Vader said, "Luke, I am your father.""' + assert s2[1:-1] == strip_matching_quotes(s2) diff --git a/tests/test_dbspecial.py b/test/test_dbspecial.py similarity index 73% rename from tests/test_dbspecial.py rename to test/test_dbspecial.py index 17309b29..21e389ce 100644 --- a/tests/test_dbspecial.py +++ b/test/test_dbspecial.py @@ -1,11 +1,12 @@ from mycli.packages.completion_engine import suggest_type -from test_completion_engine import sorted_dicts +from .test_completion_engine import sorted_dicts from mycli.packages.special.utils import format_uptime + def test_u_suggests_databases(): suggestions = suggest_type('\\u ', '\\u ') assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'database'}]) + {'type': 'database'}]) def test_describe_table(): @@ -16,6 +17,14 @@ def test_describe_table(): {'type': 'schema'}]) +def test_list_or_show_create_tables(): + suggestions = suggest_type('\\dt+', '\\dt+ ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + def test_format_uptime(): seconds = 59 assert '59 sec' == format_uptime(seconds) diff --git a/test/test_main.py b/test/test_main.py new file mode 100644 index 00000000..7731603e --- /dev/null +++ b/test/test_main.py @@ -0,0 +1,528 @@ +import os +import shutil + +import click +from 
click.testing import CliRunner + +from mycli.main import MyCli, cli, thanks_picker +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.sqlexecute import ServerInfo +from .utils import USER, HOST, PORT, PASSWORD, dbtest, run + +from textwrap import dedent +from collections import namedtuple + +from tempfile import NamedTemporaryFile +from textwrap import dedent + + +test_dir = os.path.abspath(os.path.dirname(__file__)) +project_dir = os.path.dirname(test_dir) +default_config_file = os.path.join(project_dir, 'test', 'myclirc') +login_path_file = os.path.join(test_dir, 'mylogin.cnf') + +os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file +CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT, + '--password', PASSWORD, '--myclirc', default_config_file, + '--defaults-file', default_config_file, + '_test_db'] + + +@dbtest +def test_execute_arg(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql]) + + assert result.exit_code == 0 + assert 'abc' in result.output + + result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql]) + + assert result.exit_code == 0 + assert 'abc' in result.output + + expected = 'a\nabc\n' + + assert expected in result.output + + +@dbtest +def test_execute_arg_with_table(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table']) + expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n' + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_execute_arg_with_csv(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = 
runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv']) + expected = '"a"\n"abc"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +@dbtest +def test_batch_mode(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + + sql = ( + 'select count(*) from test;\n' + 'select * from test limit 1;' + ) + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + + assert result.exit_code == 0 + assert 'count(*)\n3\na\nabc\n' in "".join(result.output) + + +@dbtest +def test_batch_mode_table(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + + sql = ( + 'select count(*) from test;\n' + 'select * from test limit 1;' + ) + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql) + + expected = (dedent("""\ + +----------+ + | count(*) | + +----------+ + | 3 | + +----------+ + +-----+ + | a | + +-----+ + | abc | + +-----+""")) + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_batch_mode_csv(executor): + run(executor, '''create table test(a text, b text)''') + run(executor, + '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''') + + sql = 'select * from test;' + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql) + + expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +def test_thanks_picker_utf8(): + name = thanks_picker() + assert name and isinstance(name, str) + + +def test_help_strings_end_with_periods(): + """Make sure click options have help text that end with a period.""" + for param in cli.params: + if isinstance(param, click.core.Option): + assert hasattr(param, 'help') + assert param.help.endswith('.') + + +def 
test_command_descriptions_end_with_periods(): + """Make sure that mycli commands' descriptions end with a period.""" + MyCli() + for _, command in SPECIAL_COMMANDS.items(): + assert command[3].endswith('.') + + +def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): + global clickoutput + clickoutput = "" + m = MyCli(myclirc=default_config_file) + + class TestOutput(): + def get_size(self): + size = namedtuple('Size', 'rows columns') + size.columns, size.rows = terminal_size + return size + + class TestExecute(): + host = 'test' + user = 'test' + dbname = 'test' + server_info = ServerInfo.from_version_string('unknown') + port = 0 + + def server_type(self): + return ['test'] + + class PromptBuffer(): + output = TestOutput() + + m.prompt_app = PromptBuffer() + m.sqlexecute = TestExecute() + m.explicit_pager = explicit_pager + + def echo_via_pager(s): + assert expect_pager + global clickoutput + clickoutput += "".join(s) + + def secho(s): + assert not expect_pager + global clickoutput + clickoutput += s + "\n" + + monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager) + monkeypatch.setattr(click, 'secho', secho) + m.output(testdata) + if clickoutput.endswith("\n"): + clickoutput = clickoutput[:-1] + assert clickoutput == "\n".join(testdata) + + +def test_conditional_pager(monkeypatch): + testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split( + " ") + # User didn't set pager, output doesn't fit screen -> pager + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=False, + expect_pager=True + ) + # User didn't set pager, output fits screen -> no pager + output( + monkeypatch, + terminal_size=(20, 20), + testdata=testdata, + explicit_pager=False, + expect_pager=False + ) + # User manually configured pager, output doesn't fit screen -> pager + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=True, + expect_pager=True + ) + # User manually 
configured pager, output fit screen -> pager + output( + monkeypatch, + terminal_size=(20, 20), + testdata=testdata, + explicit_pager=True, + expect_pager=True + ) + + SPECIAL_COMMANDS['nopager'].handler() + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=False, + expect_pager=False + ) + SPECIAL_COMMANDS['pager'].handler('') + + +def test_reserved_space_is_integer(): + """Make sure that reserved space is returned as an integer.""" + def stub_terminal_size(): + return (5, 5) + + old_func = shutil.get_terminal_size + + shutil.get_terminal_size = stub_terminal_size + mycli = MyCli() + assert isinstance(mycli.get_reserved_space(), int) + + shutil.get_terminal_size = old_func + + +def test_list_dsn(): + runner = CliRunner() + with NamedTemporaryFile(mode="w") as myclirc: + myclirc.write(dedent("""\ + [alias_dsn] + test = mysql://test/test + """)) + myclirc.flush() + args = ['--list-dsn', '--myclirc', myclirc.name] + result = runner.invoke(cli, args=args) + assert result.output == "test\n" + result = runner.invoke(cli, args=args + ['--verbose']) + assert result.output == "test : mysql://test/test\n" + + +def test_list_ssh_config(): + runner = CliRunner() + with NamedTemporaryFile(mode="w") as ssh_config: + ssh_config.write(dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """)) + ssh_config.flush() + args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name] + result = runner.invoke(cli, args=args) + assert "test\n" in result.output + result = runner.invoke(cli, args=args + ['--verbose']) + assert "test : test.example.com\n" in result.output + + +def test_dsn(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + class Logger: + def debug(self, *args, **args_dict): + pass + def warning(self, *args, **args_dict): + pass + class MockMyCli: + config = {'alias_dsn': {}} + def __init__(self, **args): + self.logger = Logger() + 
self.destructive_warning = False + self.formatter = Formatter() + def connect(self, **args): + MockMyCli.connect_args = args + def run_query(self, query, new_line=True): + pass + + import mycli.main + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + # When a user supplies a DSN as database argument to mycli, + # use these values. + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"] + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "dsn_user" and \ + MockMyCli.connect_args["passwd"] == "dsn_passwd" and \ + MockMyCli.connect_args["host"] == "dsn_host" and \ + MockMyCli.connect_args["port"] == 1 and \ + MockMyCli.connect_args["database"] == "dsn_database" + + MockMyCli.connect_args = None + + # When a use supplies a DSN as database argument to mycli, + # and used command line arguments, use the command line + # arguments. + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", + "--user", "arg_user", + "--password", "arg_password", + "--host", "arg_host", + "--port", "3", + "--database", "arg_database", + ]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "arg_user" and \ + MockMyCli.connect_args["passwd"] == "arg_password" and \ + MockMyCli.connect_args["host"] == "arg_host" and \ + MockMyCli.connect_args["port"] == 3 and \ + MockMyCli.connect_args["database"] == "arg_database" + + MockMyCli.config = { + 'alias_dsn': { + 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' + } + } + MockMyCli.connect_args = None + + # When a user uses a DSN from the configuration file (alias_dsn), + # use these values. 
+ result = runner.invoke(cli, args=['--dsn', 'test']) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "alias_dsn_user" and \ + MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \ + MockMyCli.connect_args["host"] == "alias_dsn_host" and \ + MockMyCli.connect_args["port"] == 4 and \ + MockMyCli.connect_args["database"] == "alias_dsn_database" + + MockMyCli.config = { + 'alias_dsn': { + 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' + } + } + MockMyCli.connect_args = None + + # When a user uses a DSN from the configuration file (alias_dsn) + # and used command line arguments, use the command line arguments. + result = runner.invoke(cli, args=[ + '--dsn', 'test', '', + "--user", "arg_user", + "--password", "arg_password", + "--host", "arg_host", + "--port", "5", + "--database", "arg_database", + ]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "arg_user" and \ + MockMyCli.connect_args["passwd"] == "arg_password" and \ + MockMyCli.connect_args["host"] == "arg_host" and \ + MockMyCli.connect_args["port"] == 5 and \ + MockMyCli.connect_args["database"] == "arg_database" + + # Use a DSN without password + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user@dsn_host:6/dsn_database"] + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "dsn_user" and \ + MockMyCli.connect_args["passwd"] is None and \ + MockMyCli.connect_args["host"] == "dsn_host" and \ + MockMyCli.connect_args["port"] == 6 and \ + MockMyCli.connect_args["database"] == "dsn_database" + + +def test_ssh_config(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class 
MockMyCli: + config = {'alias_dsn': {}} + + def __init__(self, **args): + self.logger = Logger() + self.destructive_warning = False + self.formatter = Formatter() + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + # Setup temporary configuration + with NamedTemporaryFile(mode="w") as ssh_config: + ssh_config.write(dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """)) + ssh_config.flush() + + # When a user supplies a ssh config. + result = runner.invoke(mycli.main.cli, args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test" + ]) + assert result.exit_code == 0, result.output + \ + " " + str(result.exception) + assert \ + MockMyCli.connect_args["ssh_user"] == "joe" and \ + MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ + MockMyCli.connect_args["ssh_port"] == 22222 and \ + MockMyCli.connect_args["ssh_key_filename"] == os.getenv( + "HOME") + "/.ssh/gateway" + + # When a user supplies a ssh config host as argument to mycli, + # and used command line arguments, use the command line + # arguments. 
+ result = runner.invoke(mycli.main.cli, args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test", + "--ssh-user", "arg_user", + "--ssh-host", "arg_host", + "--ssh-port", "3", + "--ssh-key-filename", "/path/to/key" + ]) + assert result.exit_code == 0, result.output + \ + " " + str(result.exception) + assert \ + MockMyCli.connect_args["ssh_user"] == "arg_user" and \ + MockMyCli.connect_args["ssh_host"] == "arg_host" and \ + MockMyCli.connect_args["ssh_port"] == 3 and \ + MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" + + +@dbtest +def test_init_command_arg(executor): + init_command = "set sql_select_limit=1000" + sql = 'show variables like "sql_select_limit";' + runner = CliRunner() + result = runner.invoke( + cli, args=CLI_ARGS + ["--init-command", init_command], input=sql + ) + + expected = "sql_select_limit\t1000\n" + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_init_command_multiple_arg(executor): + init_command = 'set sql_select_limit=2000; set max_join_size=20000' + sql = ( + 'show variables like "sql_select_limit";\n' + 'show variables like "max_join_size"' + ) + runner = CliRunner() + result = runner.invoke( + cli, args=CLI_ARGS + ['--init-command', init_command], input=sql + ) + + expected_sql_select_limit = 'sql_select_limit\t2000\n' + expected_max_join_size = 'max_join_size\t20000\n' + + assert result.exit_code == 0 + assert expected_sql_select_limit in result.output + assert expected_max_join_size in result.output diff --git a/tests/test_naive_completion.py b/test/test_naive_completion.py similarity index 56% rename from tests/test_naive_completion.py rename to test/test_naive_completion.py index 57d738bc..32b2abdf 100644 --- a/tests/test_naive_completion.py +++ b/test/test_naive_completion.py @@ -1,48 +1,63 @@ -from __future__ import unicode_literals import pytest from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document + @pytest.fixture 
def completer(): import mycli.sqlcompleter as sqlcompleter return sqlcompleter.SQLCompleter(smart_completion=False) + @pytest.fixture def complete_event(): - from mock import Mock + from unittest.mock import Mock return Mock() + def test_empty_string_completion(completer, complete_event): text = '' position = 0 - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert result == set(map(Completion, completer.all_completions)) + assert result == list(map(Completion, sorted(completer.all_completions))) + def test_select_keyword_completion(completer, complete_event): text = 'SEL' position = len('SEL') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert result == set([Completion(text='SELECT', start_position=-3)]) + assert result == list([Completion(text='SELECT', start_position=-3)]) + def test_function_name_completion(completer, complete_event): text = 'SELECT MA' position = len('SELECT MA') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert result == set([ - Completion(text='MAX', start_position=-2), - Completion(text='MASTER', start_position=-2)]) + assert result == list([ + Completion(text='MASTER', start_position=-2), + Completion(text='MAX', start_position=-2)]) + def test_column_name_completion(completer, complete_event): text = 'SELECT FROM users' position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list(map(Completion, sorted(completer.all_completions))) + + +def test_special_name_completion(completer, complete_event): + text = '\\' + position = len('\\') result = set(completer.get_completions( Document(text=text, cursor_position=position), 
complete_event)) - assert result == set(map(Completion, completer.all_completions)) + # Special commands will NOT be suggested during naive completion mode. + assert result == set() diff --git a/tests/test_parseutils.py b/test/test_parseutils.py similarity index 53% rename from tests/test_parseutils.py rename to test/test_parseutils.py index e512632c..920a08db 100644 --- a/tests/test_parseutils.py +++ b/test/test_parseutils.py @@ -1,55 +1,69 @@ import pytest -from mycli.packages.parseutils import extract_tables +from mycli.packages.parseutils import ( + extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause, + is_dropping_database) def test_empty_string(): tables = extract_tables('') assert tables == [] + def test_simple_select_single_table(): tables = extract_tables('select * from abc') assert tables == [(None, 'abc', None)] + def test_simple_select_single_table_schema_qualified(): tables = extract_tables('select * from abc.def') assert tables == [('abc', 'def', None)] + def test_simple_select_multiple_tables(): tables = extract_tables('select * from abc, def') assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + def test_simple_select_multiple_tables_schema_qualified(): tables = extract_tables('select * from abc.def, ghi.jkl') assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] + def test_simple_select_with_cols_single_table(): tables = extract_tables('select a,b from abc') assert tables == [(None, 'abc', None)] + def test_simple_select_with_cols_single_table_schema_qualified(): tables = extract_tables('select a,b from abc.def') assert tables == [('abc', 'def', None)] + def test_simple_select_with_cols_multiple_tables(): tables = extract_tables('select a,b from abc, def') assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + def test_simple_select_with_cols_multiple_tables_with_schema(): tables = extract_tables('select a,b from abc.def, def.ghi') assert sorted(tables) == 
[('abc', 'def', None), ('def', 'ghi', None)] + def test_select_with_hanging_comma_single_table(): tables = extract_tables('select a, from abc') assert tables == [(None, 'abc', None)] + def test_select_with_hanging_comma_multiple_tables(): tables = extract_tables('select a, from abc, def') assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + def test_select_with_hanging_period_multiple_tables(): tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] + def test_simple_insert_single_table(): tables = extract_tables('insert into abc (id, name) values (1, "def")') @@ -57,27 +71,120 @@ def test_simple_insert_single_table(): # assert tables == [(None, 'abc', None)] assert tables == [(None, 'abc', 'abc')] + @pytest.mark.xfail def test_simple_insert_single_table_schema_qualified(): tables = extract_tables('insert into abc.def (id, name) values (1, "def")') assert tables == [('abc', 'def', None)] + def test_simple_update_table(): tables = extract_tables('update abc set id = 1') assert tables == [(None, 'abc', None)] + def test_simple_update_table_with_schema(): tables = extract_tables('update abc.def set id = 1') assert tables == [('abc', 'def', None)] + def test_join_table(): tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] + def test_join_table_schema_qualified(): - tables = extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') + tables = extract_tables( + 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] + def test_join_as_table(): tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') assert tables == [(None, 'my_table', 'm')] + + +def test_query_starts_with(): + query = 'USE test;' + assert query_starts_with(query, ('use', )) is True + + query = 'DROP DATABASE test;' + assert 
query_starts_with(query, ('use', )) is False + + +def test_query_starts_with_comment(): + query = '# comment\nUSE test;' + assert query_starts_with(query, ('use', )) is True + + +def test_queries_start_with(): + sql = ( + '# comment\n' + 'show databases;' + 'use foo;' + ) + assert queries_start_with(sql, ('show', 'select')) is True + assert queries_start_with(sql, ('use', 'drop')) is True + assert queries_start_with(sql, ('delete', 'update')) is False + + +def test_is_destructive(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'drop database foo;' + ) + assert is_destructive(sql) is True + + +def test_is_destructive_update_with_where_clause(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'UPDATE test SET x = 1 WHERE id = 1;' + ) + assert is_destructive(sql) is False + + +def test_is_destructive_update_without_where_clause(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'UPDATE test SET x = 1;' + ) + assert is_destructive(sql) is True + + +@pytest.mark.parametrize( + ('sql', 'has_where_clause'), + [ + ('update test set dummy = 1;', False), + ('update test set dummy = 1 where id = 1);', True), + ], +) +def test_query_has_where_clause(sql, has_where_clause): + assert query_has_where_clause(sql) is has_where_clause + + +@pytest.mark.parametrize( + ('sql', 'dbname', 'is_dropping'), + [ + ('select bar from foo', 'foo', False), + ('drop database "foo";', '`foo`', True), + ('drop schema foo', 'foo', True), + ('drop schema foo', 'bar', False), + ('drop database bar', 'foo', False), + ('drop database foo', None, False), + ('drop database foo; create database foo', 'foo', False), + ('drop database foo; create database bar', 'foo', True), + ('select bar from foo; drop database bazz', 'foo', False), + ('select bar from foo; drop database bazz', 'bazz', True), + ('-- dropping database \n ' + 'drop -- really dropping \n ' + 'schema abc -- now it is dropped', + 'abc', + True) + ] +) +def test_is_dropping_database(sql, dbname, is_dropping): + assert 
is_dropping_database(sql, dbname) == is_dropping diff --git a/tests/test_plan.wiki b/test/test_plan.wiki similarity index 100% rename from tests/test_plan.wiki rename to test/test_plan.wiki diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py new file mode 100644 index 00000000..2373fac8 --- /dev/null +++ b/test/test_prompt_utils.py @@ -0,0 +1,11 @@ +import click + +from mycli.packages.prompt_utils import confirm_destructive_query + + +def test_confirm_destructive_query_notty(): + stdin = click.get_text_stream('stdin') + assert stdin.isatty() is False + + sql = 'drop database foo;' + assert confirm_destructive_query(sql) is None diff --git a/tests/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py similarity index 61% rename from tests/test_smart_completion_public_schema_only.py rename to test/test_smart_completion_public_schema_only.py index e99567a0..e7d460a8 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -1,15 +1,16 @@ -# coding: utf-8 -from __future__ import unicode_literals import pytest +from unittest.mock import patch from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document +import mycli.packages.special.main as special metadata = { - 'users': ['id', 'email', 'first_name', 'last_name'], - 'orders': ['id', 'ordered_date', 'status'], - 'select': ['id', 'insert', 'ABC'], - 'réveillé': ['id', 'insert', 'ABC'] - } + 'users': ['id', 'email', 'first_name', 'last_name'], + 'orders': ['id', 'ordered_date', 'status'], + 'select': ['id', 'insert', 'ABC'], + 'réveillé': ['id', 'insert', 'ABC'] +} + @pytest.fixture def completer(): @@ -27,22 +28,36 @@ def completer(): comp.extend_schemata('test') comp.extend_relations(tables, kind='tables') comp.extend_columns(columns, kind='tables') + comp.extend_special_commands(special.COMMANDS) return comp + @pytest.fixture def complete_event(): - from mock import 
Mock + from unittest.mock import Mock return Mock() + +def test_special_name_completion(completer, complete_event): + text = '\\d' + position = len('\\d') + result = completer.get_completions( + Document(text=text, cursor_position=position), + complete_event) + assert result == [Completion(text='\\dt', start_position=-2)] + + def test_empty_string_completion(completer, complete_event): text = '' position = 0 - result = set( + result = list( completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(map(Completion, completer.keywords)) == result + assert list(map(Completion, sorted(completer.keywords) + + sorted(completer.special_commands))) == result + def test_select_keyword_completion(completer, complete_event): text = 'SEL' @@ -50,7 +65,7 @@ def test_select_keyword_completion(completer, complete_event): result = completer.get_completions( Document(text=text, cursor_position=position), complete_event) - assert set(result) == set([Completion(text='SELECT', start_position=-3)]) + assert list(result) == list([Completion(text='SELECT', start_position=-3)]) def test_table_completion(completer, complete_event): @@ -58,10 +73,12 @@ def test_table_completion(completer, complete_event): position = len(text) result = completer.get_completions( Document(text=text, cursor_position=position), complete_event) - assert set(result) == set([Completion(text='users', start_position=0), - Completion(text='`select`', start_position=0), - Completion(text='`réveillé`', start_position=0), - Completion(text='orders', start_position=0)]) + assert list(result) == list([ + Completion(text='`réveillé`', start_position=0), + Completion(text='`select`', start_position=0), + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0), + ]) def test_function_name_completion(completer, complete_event): @@ -69,228 +86,300 @@ def test_function_name_completion(completer, complete_event): position = len('SELECT MA') result = 
completer.get_completions( Document(text=text, cursor_position=position), complete_event) - assert set(result) == set([Completion(text='MAX', start_position=-2), + assert list(result) == list([Completion(text='MAX', start_position=-2), Completion(text='MASTER', start_position=-2), ]) + def test_suggested_column_names(completer, complete_event): - """ - Suggest column and function names when selecting from table + """Suggest column and function names when selecting from table. + :param completer: :param complete_event: :return: + """ text = 'SELECT from users' position = len('SELECT ') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ + assert result == list([ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), Completion(text='email', start_position=0), Completion(text='first_name', start_position=0), - Completion(text='last_name', start_position=0)] + + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0), + ] + list(map(Completion, completer.functions)) + + [Completion(text='users', start_position=0)] + list(map(Completion, completer.keywords))) + def test_suggested_column_names_in_function(completer, complete_event): - """ - Suggest column and function names when selecting multiple - columns from table + """Suggest column and function names when selecting multiple columns from + table. 
+ :param completer: :param complete_event: :return: + """ text = 'SELECT MAX( from users' position = len('SELECT MAX(') result = completer.get_completions( Document(text=text, cursor_position=position), complete_event) - assert set(result) == set([ + assert list(result) == list([ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), Completion(text='email', start_position=0), Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), Completion(text='last_name', start_position=0)]) + def test_suggested_column_names_with_table_dot(completer, complete_event): - """ - Suggest column names on table name and dot + """Suggest column names on table name and dot. + :param completer: :param complete_event: :return: + """ text = 'SELECT users. from users' position = len('SELECT users.') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ + assert result == list([ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), Completion(text='email', start_position=0), Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), Completion(text='last_name', start_position=0)]) + def test_suggested_column_names_with_alias(completer, complete_event): - """ - Suggest column names on table alias and dot + """Suggest column names on table alias and dot. + :param completer: :param complete_event: :return: + """ text = 'SELECT u. 
from users u' position = len('SELECT u.') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ + assert result == list([ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), Completion(text='email', start_position=0), Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), Completion(text='last_name', start_position=0)]) + def test_suggested_multiple_column_names(completer, complete_event): - """ - Suggest column and function names when selecting multiple - columns from table + """Suggest column and function names when selecting multiple columns from + table. + :param completer: :param complete_event: :return: + """ text = 'SELECT id, from users u' position = len('SELECT id, ') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ + assert result == list([ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), Completion(text='email', start_position=0), Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), Completion(text='last_name', start_position=0)] + list(map(Completion, completer.functions)) + + [Completion(text='u', start_position=0)] + list(map(Completion, completer.keywords))) + def test_suggested_multiple_column_names_with_alias(completer, complete_event): - """ - Suggest column names on table alias and dot - when selecting multiple columns from table + """Suggest column names on table alias and dot when selecting multiple + columns from table. + :param completer: :param complete_event: :return: + """ text = 'SELECT u.id, u. 
from users u' position = len('SELECT u.id, u.') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ + assert result == list([ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), Completion(text='email', start_position=0), Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), Completion(text='last_name', start_position=0)]) + def test_suggested_multiple_column_names_with_dot(completer, complete_event): - """ - Suggest column names on table names and dot - when selecting multiple columns from table + """Suggest column names on table names and dot when selecting multiple + columns from table. + :param completer: :param complete_event: :return: + """ text = 'SELECT users.id, users. from users u' position = len('SELECT users.id, users.') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ + assert result == list([ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), Completion(text='email', start_position=0), Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), Completion(text='last_name', start_position=0)]) + def test_suggested_aliases_after_on(completer, complete_event): text = 'SELECT u.name, o.id FROM users u JOIN orders o ON ' position = len('SELECT u.name, o.id FROM users u JOIN orders o ON ') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ - Completion(text='u', start_position=0), - Completion(text='o', start_position=0)]) + assert result == list([ + Completion(text='o', start_position=0), + Completion(text='u', start_position=0)]) + def 
test_suggested_aliases_after_on_right_side(completer, complete_event): text = 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ' - position = len('SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ') - result = set(completer.get_completions( + position = len( + 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ') + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ - Completion(text='u', start_position=0), - Completion(text='o', start_position=0)]) + assert result == list([ + Completion(text='o', start_position=0), + Completion(text='u', start_position=0)]) + def test_suggested_tables_after_on(completer, complete_event): text = 'SELECT users.name, orders.id FROM users JOIN orders ON ' position = len('SELECT users.name, orders.id FROM users JOIN orders ON ') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ - Completion(text='users', start_position=0), - Completion(text='orders', start_position=0)]) + assert result == list([ + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0)]) + def test_suggested_tables_after_on_right_side(completer, complete_event): text = 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ' - position = len('SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ') - result = set(completer.get_completions( + position = len( + 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ') + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ - Completion(text='users', start_position=0), - Completion(text='orders', start_position=0)]) + assert result == list([ + Completion(text='orders', start_position=0), + 
Completion(text='users', start_position=0)]) + def test_table_names_after_from(completer, complete_event): text = 'SELECT * FROM ' position = len('SELECT * FROM ') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ - Completion(text='users', start_position=0), - Completion(text='orders', start_position=0), + assert result == list([ Completion(text='`réveillé`', start_position=0), Completion(text='`select`', start_position=0), - ]) + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0), + ]) + def test_auto_escaped_col_names(completer, complete_event): text = 'SELECT from `select`' position = len('SELECT ') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ + assert result == [ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), + Completion(text='`ABC`', start_position=0), Completion(text='`insert`', start_position=0), - Completion(text='`ABC`', start_position=0), ] + - list(map(Completion, completer.functions)) + - list(map(Completion, completer.keywords))) + Completion(text='id', start_position=0), + ] + \ + list(map(Completion, completer.functions)) + \ + [Completion(text='`select`', start_position=0)] + \ + list(map(Completion, completer.keywords)) + def test_un_escaped_table_names(completer, complete_event): text = 'SELECT from réveillé' position = len('SELECT ') - result = set(completer.get_completions( + result = list(completer.get_completions( Document(text=text, cursor_position=position), complete_event)) - assert set(result) == set([ + assert result == list([ Completion(text='*', start_position=0), - Completion(text='id', start_position=0), + Completion(text='`ABC`', start_position=0), Completion(text='`insert`', start_position=0), - 
Completion(text='`ABC`', start_position=0), ] + + Completion(text='id', start_position=0), + ] + list(map(Completion, completer.functions)) + + [Completion(text='réveillé', start_position=0)] + list(map(Completion, completer.keywords))) + + +def dummy_list_path(dir_name): + dirs = { + '/': [ + 'dir1', + 'file1.sql', + 'file2.sql', + ], + '/dir1': [ + 'subdir1', + 'subfile1.sql', + 'subfile2.sql', + ], + '/dir1/subdir1': [ + 'lastfile.sql', + ], + } + return dirs.get(dir_name, []) + + +@patch('mycli.packages.filepaths.list_path', new=dummy_list_path) +@pytest.mark.parametrize('text,expected', [ + # ('source ', [('~', 0), + # ('/', 0), + # ('.', 0), + # ('..', 0)]), + ('source /', [('dir1', 0), + ('file1.sql', 0), + ('file2.sql', 0)]), + ('source /dir1/', [('subdir1', 0), + ('subfile1.sql', 0), + ('subfile2.sql', 0)]), + ('source /dir1/subdir1/', [('lastfile.sql', 0)]), +]) +def test_file_name_completion(completer, complete_event, text, expected): + position = len(text) + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + expected = list((Completion(txt, pos) for txt, pos in expected)) + assert result == expected diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py new file mode 100644 index 00000000..8b6be337 --- /dev/null +++ b/test/test_special_iocommands.py @@ -0,0 +1,287 @@ +import os +import stat +import tempfile +from time import time +from unittest.mock import patch + +import pytest +from pymysql import ProgrammingError + +import mycli.packages.special + +from .utils import dbtest, db_connection, send_ctrl_c + + +def test_set_get_pager(): + mycli.packages.special.set_pager_enabled(True) + assert mycli.packages.special.is_pager_enabled() + mycli.packages.special.set_pager_enabled(False) + assert not mycli.packages.special.is_pager_enabled() + mycli.packages.special.set_pager('less') + assert os.environ['PAGER'] == "less" + mycli.packages.special.set_pager(False) + assert 
os.environ['PAGER'] == "less" + del os.environ['PAGER'] + mycli.packages.special.set_pager(False) + mycli.packages.special.disable_pager() + assert not mycli.packages.special.is_pager_enabled() + + +def test_set_get_timing(): + mycli.packages.special.set_timing_enabled(True) + assert mycli.packages.special.is_timing_enabled() + mycli.packages.special.set_timing_enabled(False) + assert not mycli.packages.special.is_timing_enabled() + + +def test_set_get_expanded_output(): + mycli.packages.special.set_expanded_output(True) + assert mycli.packages.special.is_expanded_output() + mycli.packages.special.set_expanded_output(False) + assert not mycli.packages.special.is_expanded_output() + + +def test_editor_command(): + assert mycli.packages.special.editor_command(r'hello\e') + assert mycli.packages.special.editor_command(r'\ehello') + assert not mycli.packages.special.editor_command(r'hello') + + assert mycli.packages.special.get_filename(r'\e filename') == "filename" + + os.environ['EDITOR'] = 'true' + os.environ['VISUAL'] = 'true' + mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1" + + +def test_tee_command(): + mycli.packages.special.write_tee(u"hello world") # write without file set + with tempfile.NamedTemporaryFile() as f: + mycli.packages.special.execute(None, u"tee " + f.name) + mycli.packages.special.write_tee(u"hello world") + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"tee -o " + f.name) + mycli.packages.special.write_tee(u"hello world") + f.seek(0) + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"notee") + mycli.packages.special.write_tee(u"hello world") + f.seek(0) + assert f.read() == b"hello world\n" + + +def test_tee_command_error(): + with pytest.raises(TypeError): + mycli.packages.special.execute(None, 'tee') + + with pytest.raises(OSError): + with tempfile.NamedTemporaryFile() as f: + os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) + 
mycli.packages.special.execute(None, 'tee {}'.format(f.name)) + + +@dbtest +def test_favorite_query(): + with db_connection().cursor() as cur: + query = u'select "✔"' + mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query)) + assert next(mycli.packages.special.execute( + cur, u'\\f check'))[0] == "> " + query + + +def test_once_command(): + with pytest.raises(TypeError): + mycli.packages.special.execute(None, u"\\once") + + with pytest.raises(OSError): + mycli.packages.special.execute(None, u"\\once /proc/access-denied") + + mycli.packages.special.write_once(u"hello world") # write without file set + with tempfile.NamedTemporaryFile() as f: + mycli.packages.special.execute(None, u"\\once " + f.name) + mycli.packages.special.write_once(u"hello world") + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"\\once -o " + f.name) + mycli.packages.special.write_once(u"hello world line 1") + mycli.packages.special.write_once(u"hello world line 2") + f.seek(0) + assert f.read() == b"hello world line 1\nhello world line 2\n" + + +def test_pipe_once_command(): + with pytest.raises(IOError): + mycli.packages.special.execute(None, u"\\pipe_once") + + with pytest.raises(OSError): + mycli.packages.special.execute( + None, u"\\pipe_once /proc/access-denied") + + mycli.packages.special.execute(None, u"\\pipe_once wc") + mycli.packages.special.write_once(u"hello world") + mycli.packages.special.unset_pipe_once_if_written() + # how to assert on wc output? 
+ + +def test_parseargfile(): + """Test that parseargfile expands the user directory.""" + expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), + 'mode': 'a'} + assert expected == mycli.packages.special.iocommands.parseargfile( + '~/filename') + + expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), + 'mode': 'w'} + assert expected == mycli.packages.special.iocommands.parseargfile( + '-o ~/filename') + + +def test_parseargfile_no_file(): + """Test that parseargfile raises a TypeError if there is no filename.""" + with pytest.raises(TypeError): + mycli.packages.special.iocommands.parseargfile('') + + with pytest.raises(TypeError): + mycli.packages.special.iocommands.parseargfile('-o ') + + +@dbtest +def test_watch_query_iteration(): + """Test that a single iteration of the result of `watch_query` executes + the desired query and returns the given results.""" + expected_value = "1" + query = "SELECT {0!s}".format(expected_value) + expected_title = '> {0!s}'.format(query) + with db_connection().cursor() as cur: + result = next(mycli.packages.special.iocommands.watch_query( + arg=query, cur=cur + )) + assert result[0] == expected_title + assert result[2][0] == expected_value + + +@dbtest +def test_watch_query_full(): + """Test that `watch_query`: + + * Returns the expected results. + * Executes the defined times inside the given interval, in this case with + a 0.3 seconds wait, it should execute 4 times inside a 1 seconds + interval. 
+ * Stops at Ctrl-C + + """ + watch_seconds = 0.3 + wait_interval = 1 + expected_value = "1" + query = "SELECT {0!s}".format(expected_value) + expected_title = '> {0!s}'.format(query) + expected_results = 4 + ctrl_c_process = send_ctrl_c(wait_interval) + with db_connection().cursor() as cur: + results = list( + result for result in mycli.packages.special.iocommands.watch_query( + arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur + ) + ) + ctrl_c_process.join(1) + assert len(results) == expected_results + for result in results: + assert result[0] == expected_title + assert result[2][0] == expected_value + + +@dbtest +@patch('click.clear') +def test_watch_query_clear(clear_mock): + """Test that the screen is cleared with the -c flag of `watch` command + before execute the query.""" + with db_connection().cursor() as cur: + watch_gen = mycli.packages.special.iocommands.watch_query( + arg='0.1 -c select 1;', cur=cur + ) + assert not clear_mock.called + next(watch_gen) + assert clear_mock.called + clear_mock.reset_mock() + next(watch_gen) + assert clear_mock.called + clear_mock.reset_mock() + + +@dbtest +def test_watch_query_bad_arguments(): + """Test different incorrect combinations of arguments for `watch` + command.""" + watch_query = mycli.packages.special.iocommands.watch_query + with db_connection().cursor() as cur: + with pytest.raises(ProgrammingError): + next(watch_query('a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('-a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('1 -a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('-c -a select 1;', cur=cur)) + + +@dbtest +@patch('click.clear') +def test_watch_query_interval_clear(clear_mock): + """Test `watch` command with interval and clear flag.""" + def test_asserts(gen): + clear_mock.reset_mock() + start = time() + next(gen) + assert clear_mock.called + next(gen) + exec_time = time() - start + assert 
exec_time > seconds and exec_time < (seconds + seconds) + + seconds = 1.0 + watch_query = mycli.packages.special.iocommands.watch_query + with db_connection().cursor() as cur: + test_asserts(watch_query('{0!s} -c select 1;'.format(seconds), + cur=cur)) + test_asserts(watch_query('-c {0!s} select 1;'.format(seconds), + cur=cur)) + + +def test_split_sql_by_delimiter(): + for delimiter_str in (';', '$', '😀'): + mycli.packages.special.set_delimiter(delimiter_str) + sql_input = "select 1{} select \ufffc2".format(delimiter_str) + queries = ( + "select 1", + "select \ufffc2" + ) + for query, parsed_query in zip( + queries, mycli.packages.special.split_queries(sql_input)): + assert(query == parsed_query) + + +def test_switch_delimiter_within_query(): + mycli.packages.special.set_delimiter(';') + sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$" + queries = ( + "select 1", + "delimiter $$ select 2 $$ select 3 $$", + "select 2", + "select 3" + ) + for query, parsed_query in zip( + queries, + mycli.packages.special.split_queries(sql_input)): + assert(query == parsed_query) + + +def test_set_delimiter(): + + for delim in ('foo', 'bar'): + mycli.packages.special.set_delimiter(delim) + assert mycli.packages.special.get_current_delimiter() == delim + + +def teardown_function(): + mycli.packages.special.set_delimiter(';') diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py new file mode 100644 index 00000000..0f38a97e --- /dev/null +++ b/test/test_sqlexecute.py @@ -0,0 +1,294 @@ +import os + +import pytest +import pymysql + +from mycli.sqlexecute import ServerInfo, ServerSpecies +from .utils import run, dbtest, set_expanded_output, is_expanded_output + + +def assert_result_equal(result, title=None, rows=None, headers=None, + status=None, auto_status=True, assert_contains=False): + """Assert that an sqlexecute.run() result matches the expected values.""" + if status is None and auto_status and rows: + status = '{} row{} in set'.format( + len(rows), 's' if 
len(rows) > 1 else '') + fields = {'title': title, 'rows': rows, 'headers': headers, + 'status': status} + + if assert_contains: + # Do a loose match on the results using the *in* operator. + for key, field in fields.items(): + if field: + assert field in result[0][key] + else: + # Do an exact match on the fields. + assert result == [fields] + + +@dbtest +def test_conn(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc')''') + results = run(executor, '''select * from test''') + + assert_result_equal(results, headers=['a'], rows=[('abc',)]) + + +@dbtest +def test_bools(executor): + run(executor, '''create table test(a boolean)''') + run(executor, '''insert into test values(True)''') + results = run(executor, '''select * from test''') + + assert_result_equal(results, headers=['a'], rows=[(1,)]) + + +@dbtest +def test_binary(executor): + run(executor, '''create table bt(geom linestring NOT NULL)''') + run(executor, "INSERT INTO bt VALUES " + "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") + results = run(executor, '''select * from bt''') + + geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n' + b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9' + b'\xac\xdeC@') + + assert_result_equal(results, headers=['geom'], rows=[(geom,)]) + + +@dbtest +def test_table_and_columns_query(executor): + run(executor, "create table a(x text, y text)") + run(executor, "create table b(z text)") + + assert set(executor.tables()) == set([('a',), ('b',)]) + assert set(executor.table_columns()) == set( + [('a', 'x'), ('a', 'y'), ('b', 'z')]) + + +@dbtest +def test_database_list(executor): + databases = executor.databases() + assert '_test_db' in databases + + +@dbtest +def test_invalid_syntax(executor): + with pytest.raises(pymysql.ProgrammingError) as excinfo: + run(executor, 'invalid syntax!') + assert 'You have an error in your SQL syntax;' in str(excinfo.value) + + 
+@dbtest +def test_invalid_column_name(executor): + with pytest.raises(pymysql.err.OperationalError) as excinfo: + run(executor, 'select invalid command') + assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) + + +@dbtest +def test_unicode_support_in_output(executor): + run(executor, "create table unicodechars(t text)") + run(executor, u"insert into unicodechars (t) values ('é')") + + # See issue #24, this raises an exception without proper handling + results = run(executor, u"select * from unicodechars") + assert_result_equal(results, headers=['t'], rows=[(u'é',)]) + + +@dbtest +def test_multiple_queries_same_line(executor): + results = run(executor, "select 'foo'; select 'bar'") + + expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)], + 'status': '1 row in set'}, + {'title': None, 'headers': ['bar'], 'rows': [('bar',)], + 'status': '1 row in set'}] + assert expected == results + + +@dbtest +def test_multiple_queries_same_line_syntaxerror(executor): + with pytest.raises(pymysql.ProgrammingError) as excinfo: + run(executor, "select 'foo'; invalid syntax") + assert 'You have an error in your SQL syntax;' in str(excinfo.value) + + +@dbtest +def test_favorite_query(executor): + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run(executor, "\\fs test-a select * from test where a like 'a%'") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-a") + assert_result_equal(results, + title="> select * from test where a like 'a%'", + headers=['a'], rows=[('abc',)], auto_status=False) + + results = run(executor, "\\fd test-a") + assert_result_equal(results, status='test-a: Deleted') + + +@dbtest +def test_favorite_query_multiple_statement(executor): + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + 
run(executor, "insert into test values('def')") + + results = run(executor, + "\\fs test-ad select * from test where a like 'a%'; " + "select * from test where a like 'd%'") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-ad") + expected = [{'title': "> select * from test where a like 'a%'", + 'headers': ['a'], 'rows': [('abc',)], 'status': None}, + {'title': "> select * from test where a like 'd%'", + 'headers': ['a'], 'rows': [('def',)], 'status': None}] + assert expected == results + + results = run(executor, "\\fd test-ad") + assert_result_equal(results, status='test-ad: Deleted') + + +@dbtest +def test_favorite_query_expanded_output(executor): + set_expanded_output(False) + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc')''') + + results = run(executor, "\\fs test-ae select * from test") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-ae \\G") + assert is_expanded_output() is True + assert_result_equal(results, title='> select * from test', + headers=['a'], rows=[('abc',)], auto_status=False) + + set_expanded_output(False) + + results = run(executor, "\\fd test-ae") + assert_result_equal(results, status='test-ae: Deleted') + + +@dbtest +def test_special_command(executor): + results = run(executor, '\\?') + assert_result_equal(results, rows=('quit', '\\q', 'Quit.'), + headers='Command', assert_contains=True, + auto_status=False) + + +@dbtest +def test_cd_command_without_a_folder_name(executor): + results = run(executor, 'system cd') + assert_result_equal(results, status='No folder name was provided.') + + +@dbtest +def test_system_command_not_found(executor): + results = run(executor, 'system xyz') + assert_result_equal(results, status='OSError: No such file or directory', + assert_contains=True) + + +@dbtest +def test_system_command_output(executor): + test_dir = os.path.abspath(os.path.dirname(__file__)) + test_file_path = 
os.path.join(test_dir, 'test.txt') + results = run(executor, 'system cat {0}'.format(test_file_path)) + assert_result_equal(results, status='mycli rocks!\n') + + +@dbtest +def test_cd_command_current_dir(executor): + test_path = os.path.abspath(os.path.dirname(__file__)) + run(executor, 'system cd {0}'.format(test_path)) + assert os.getcwd() == test_path + + +@dbtest +def test_unicode_support(executor): + results = run(executor, u"SELECT '日本語' AS japanese;") + assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)]) + + +@dbtest +def test_timestamp_null(executor): + run(executor, '''create table ts_null(a timestamp null)''') + run(executor, '''insert into ts_null values(null)''') + results = run(executor, '''select * from ts_null''') + assert_result_equal(results, headers=['a'], + rows=[(None,)]) + + +@dbtest +def test_datetime_null(executor): + run(executor, '''create table dt_null(a datetime null)''') + run(executor, '''insert into dt_null values(null)''') + results = run(executor, '''select * from dt_null''') + assert_result_equal(results, headers=['a'], + rows=[(None,)]) + + +@dbtest +def test_date_null(executor): + run(executor, '''create table date_null(a date null)''') + run(executor, '''insert into date_null values(null)''') + results = run(executor, '''select * from date_null''') + assert_result_equal(results, headers=['a'], rows=[(None,)]) + + +@dbtest +def test_time_null(executor): + run(executor, '''create table time_null(a time null)''') + run(executor, '''insert into time_null values(null)''') + results = run(executor, '''select * from time_null''') + assert_result_equal(results, headers=['a'], rows=[(None,)]) + + +@dbtest +def test_multiple_results(executor): + query = '''CREATE PROCEDURE dmtest() + BEGIN + SELECT 1; + SELECT 2; + END''' + executor.conn.cursor().execute(query) + + results = run(executor, 'call dmtest;') + expected = [ + {'title': None, 'rows': [(1,)], 'headers': ['1'], + 'status': '1 row in set'}, + {'title': None, 
'rows': [(2,)], 'headers': ['2'], + 'status': '1 row in set'} + ] + assert results == expected + + +@pytest.mark.parametrize( + 'version_string, species, parsed_version_string, version', + ( + ('5.7.32-35', 'Percona', '5.7.32', 50732), + ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), + ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), + ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), + ('unexpected version string', None, '', 0), + ('', None, '', 0), + (None, None, '', 0), + ) +) +def test_version_parsing(version_string, species, parsed_version_string, version): + server_info = ServerInfo.from_version_string(version_string) + assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown + assert server_info.version_str == parsed_version_string + assert server_info.version == version diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py new file mode 100644 index 00000000..c20c7de2 --- /dev/null +++ b/test/test_tabular_output.py @@ -0,0 +1,118 @@ +"""Test the sql output adapter.""" + +from textwrap import dedent + +from mycli.packages.tabular_output import sql_format +from cli_helpers.tabular_output import TabularOutputFormatter + +from .utils import USER, PASSWORD, HOST, PORT, dbtest + +import pytest +from mycli.main import MyCli + +from pymysql.constants import FIELD_TYPE + + +@pytest.fixture +def mycli(): + cli = MyCli() + cli.connect(None, USER, PASSWORD, HOST, PORT, None, init_command=None) + return cli + + +@dbtest +def test_sql_output(mycli): + """Test the sql output adapter.""" + headers = ['letters', 'number', 'optional', 'float', 'binary'] + + class FakeCursor(object): + def __init__(self): + self.data = [ + ('abc', 1, None, 10.0, b'\xAA'), + ('d', 456, '1', 0.5, b'\xAA\xBB') + ] + self.description = [ + (None, FIELD_TYPE.VARCHAR), + (None, FIELD_TYPE.LONG), + 
(None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.FLOAT), + (None, FIELD_TYPE.BLOB) + ] + + def __iter__(self): + return self + + def __next__(self): + if self.data: + return self.data.pop(0) + else: + raise StopIteration() + + def description(self): + return self.description + + # Test sql-update output format + assert list(mycli.change_table_format("sql-update")) == \ + [(None, None, None, 'Changed table format to sql-update')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + actual = "\n".join(output) + assert actual == dedent('''\ + UPDATE `DUAL` SET + `number` = 1 + , `optional` = NULL + , `float` = 10.0e0 + , `binary` = X'aa' + WHERE `letters` = 'abc'; + UPDATE `DUAL` SET + `number` = 456 + , `optional` = '1' + , `float` = 0.5e0 + , `binary` = X'aabb' + WHERE `letters` = 'd';''') + # Test sql-update-2 output format + assert list(mycli.change_table_format("sql-update-2")) == \ + [(None, None, None, 'Changed table format to sql-update-2')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + UPDATE `DUAL` SET + `optional` = NULL + , `float` = 10.0e0 + , `binary` = X'aa' + WHERE `letters` = 'abc' AND `number` = 1; + UPDATE `DUAL` SET + `optional` = '1' + , `float` = 0.5e0 + , `binary` = X'aabb' + WHERE `letters` = 'd' AND `number` = 456;''') + # Test sql-insert output format (without table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table format to sql-insert')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') + # Test sql-insert output format (with table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table 
format to sql-insert')] + mycli.formatter.query = "SELECT * FROM `table`" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') + # Test sql-insert output format (with database + table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table format to sql-insert')] + mycli.formatter.query = "SELECT * FROM `database`.`table`" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `database`.`table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 00000000..66b41940 --- /dev/null +++ b/test/utils.py @@ -0,0 +1,94 @@ +import os +import time +import signal +import platform +import multiprocessing + +import pymysql +import pytest + +from mycli.main import special + +PASSWORD = os.getenv('PYTEST_PASSWORD') +USER = os.getenv('PYTEST_USER', 'root') +HOST = os.getenv('PYTEST_HOST', 'localhost') +PORT = int(os.getenv('PYTEST_PORT', 3306)) +CHARSET = os.getenv('PYTEST_CHARSET', 'utf8') +SSH_USER = os.getenv('PYTEST_SSH_USER', None) +SSH_HOST = os.getenv('PYTEST_SSH_HOST', None) +SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22) + + +def db_connection(dbname=None): + conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, + password=PASSWORD, charset=CHARSET, + local_infile=False) + conn.autocommit = True + return conn + + +try: + db_connection() + CAN_CONNECT_TO_DB = True +except: + CAN_CONNECT_TO_DB = False + +dbtest = pytest.mark.skipif( + not CAN_CONNECT_TO_DB, + reason="Need a mysql instance at localhost accessible by user 'root'") + + +def create_db(dbname): + with db_connection().cursor() as 
cur: + try: + cur.execute('''DROP DATABASE IF EXISTS _test_db''') + cur.execute('''CREATE DATABASE _test_db''') + except: + pass + + +def run(executor, sql, rows_as_list=True): + """Return string output for the sql to be run.""" + result = [] + + for title, rows, headers, status in executor.run(sql): + rows = list(rows) if (rows_as_list and rows) else rows + result.append({'title': title, 'rows': rows, 'headers': headers, + 'status': status}) + + return result + + +def set_expanded_output(is_expanded): + """Pass-through for the tests.""" + return special.set_expanded_output(is_expanded) + + +def is_expanded_output(): + """Pass-through for the tests.""" + return special.is_expanded_output() + + +def send_ctrl_c_to_pid(pid, wait_seconds): + """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds` + seconds.""" + time.sleep(wait_seconds) + system_name = platform.system() + if system_name == "Windows": + os.kill(pid, signal.CTRL_C_EVENT) + else: + os.kill(pid, signal.SIGINT) + + +def send_ctrl_c(wait_seconds): + """Create a process that sends a Ctrl-C like signal to the current process + after `wait_seconds` seconds. + + Returns the `multiprocessing.Process` created. + + """ + ctrl_c_process = multiprocessing.Process( + target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds) + ) + ctrl_c_process.start() + return ctrl_c_process diff --git a/tests/features/basic_commands.feature b/tests/features/basic_commands.feature deleted file mode 100644 index 227fe769..00000000 --- a/tests/features/basic_commands.feature +++ /dev/null @@ -1,19 +0,0 @@ -Feature: run the cli, - call the help command, - exit the cli - - Scenario: run the cli - When we run dbcli - then we see dbcli prompt - - Scenario: run "\?" command - When we run dbcli - and we wait for prompt - and we send "\?" 
command - then we see help output - - Scenario: run the cli and exit - When we run dbcli - and we wait for prompt - and we send "ctrl + d" - then dbcli exits diff --git a/tests/features/crud_database.feature b/tests/features/crud_database.feature deleted file mode 100644 index c72468c3..00000000 --- a/tests/features/crud_database.feature +++ /dev/null @@ -1,20 +0,0 @@ -Feature: manipulate databases: - create, drop, connect, disconnect - - Scenario: create and drop temporary database - When we run dbcli - and we wait for prompt - and we create database - then we see database created - when we drop database - then we see database dropped - when we connect to dbserver - then we see database connected - - Scenario: connect and disconnect from test database - When we run dbcli - and we wait for prompt - and we connect to test database - then we see database connected - when we connect to dbserver - then we see database connected diff --git a/tests/features/crud_table.feature b/tests/features/crud_table.feature deleted file mode 100644 index d2209fd0..00000000 --- a/tests/features/crud_table.feature +++ /dev/null @@ -1,22 +0,0 @@ -Feature: manipulate tables: - create, insert, update, select, delete from, drop - - Scenario: create, insert, select from, update, drop table - When we run dbcli - and we wait for prompt - and we connect to test database - then we see database connected - when we create table - then we see table created - when we insert into table - then we see record inserted - when we update table - then we see record updated - when we select from table - then we see data selected - when we delete from table - then we see record deleted - when we drop table - then we see table dropped - when we connect to dbserver - then we see database connected diff --git a/tests/features/environment.py b/tests/features/environment.py deleted file mode 100644 index 3f55757b..00000000 --- a/tests/features/environment.py +++ /dev/null @@ -1,84 +0,0 @@ -# -*- coding: utf-8 -*- 
-from __future__ import unicode_literals -from __future__ import print_function - -import os -import sys -import db_utils as dbutils -import fixture_utils as fixutils - - -def before_all(context): - """ - Set env parameters. - """ - os.environ['LINES'] = "100" - os.environ['COLUMNS'] = "100" - os.environ['PAGER'] = 'cat' - os.environ['EDITOR'] = 'ex' - os.environ["COVERAGE_PROCESS_START"] = os.getcwd() + "/../.coveragerc" - - context.exit_sent = False - - vi = '_'.join([str(x) for x in sys.version_info[:3]]) - db_name = context.config.userdata.get('my_test_db', None) or "mycli_behave_tests" - db_name_full = '{0}_{1}'.format(db_name, vi) - - # Store get params from config/environment variables - context.conf = { - 'host': context.config.userdata.get( - 'my_test_host', - os.getenv('PYTEST_HOST', 'localhost') - ), - 'user': context.config.userdata.get( - 'my_test_user', - os.getenv('PYTEST_USER', 'root') - ), - 'pass': context.config.userdata.get( - 'my_test_pass', - os.getenv('PYTEST_PASSWORD', None) - ), - 'cli_command': context.config.userdata.get( - 'my_cli_command', None) or - sys.executable+' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', - 'dbname': db_name, - 'dbname_tmp': db_name_full + '_tmp', - 'vi': vi, - } - - context.cn = dbutils.create_db(context.conf['host'], context.conf['user'], - context.conf['pass'], - context.conf['dbname']) - - context.fixture_data = fixutils.read_fixture_files() - - -def after_all(context): - """ - Unset env parameters. - """ - dbutils.close_cn(context.cn) - dbutils.drop_db(context.conf['host'], context.conf['user'], - context.conf['pass'], context.conf['dbname']) - - # Restore env vars. - #for k, v in context.pgenv.items(): - # if k in os.environ and v is None: - # del os.environ[k] - # elif v: - # os.environ[k] = v - - -def after_scenario(context, _): - """ - Cleans up after each test complete. - """ - - if hasattr(context, 'cli') and not context.exit_sent: - # Terminate nicely. 
- context.cli.terminate() - -# TODO: uncomment to debug a failure -# def after_step(context, step): -# if step.status == "failed": -# import ipdb; ipdb.set_trace() diff --git a/tests/features/fixture_data/help_commands.txt b/tests/features/fixture_data/help_commands.txt deleted file mode 100644 index ee3d4ca4..00000000 --- a/tests/features/fixture_data/help_commands.txt +++ /dev/null @@ -1,27 +0,0 @@ -+-------------+-------------------+---------------------------------------------------------+ -| Command | Shortcut | Description | -+-------------+-------------------+---------------------------------------------------------+ -| \G | \G | Display results vertically. | -| \dt | \dt [table] | List or describe tables. | -| \e | \e | Edit command with editor. (uses $EDITOR) | -| \f | \f [name] | List or execute favorite queries. | -| \fd | \fd [name] | Delete a favorite query. | -| \fs | \fs name query | Save a favorite query. | -| \l | \l | List databases. | -| \timing | \t | Toggle timing of commands. | -| connect | \r | Reconnect to the database. Optional database argument. | -| exit | \q | Exit. | -| help | \? | Show this help. | -| nopager | \n | Disable pager, print to stdout. | -| notee | notee | stop writing to an output file | -| pager | \P [command] | Set PAGER. Print the query results via PAGER | -| prompt | \R | Change prompt format. | -| quit | \q | Quit. | -| rehash | \# | Refresh auto-completions. | -| source | \. filename | Execute commands from file. | -| status | \s | Get status information from the server. | -| system | system [command] | Execute a system commmand. | -| tableformat | \T | Change Table Type. | -| tee | tee [-o] filename | write to an output file (optionally overwrite using -o) | -| use | \u | Change to a new database. 
| -+-------------+-------------------+---------------------------------------------------------+ diff --git a/tests/features/iocommands.feature b/tests/features/iocommands.feature deleted file mode 100644 index d043dc2e..00000000 --- a/tests/features/iocommands.feature +++ /dev/null @@ -1,10 +0,0 @@ -Feature: I/O commands - - Scenario: edit sql in file with external editor - When we run dbcli - and we wait for prompt - and we start external editor providing a file name - and we type sql in the editor - and we exit the editor - then we see dbcli prompt - and we see the sql in prompt diff --git a/tests/features/named_queries.feature b/tests/features/named_queries.feature deleted file mode 100644 index 79f31ac3..00000000 --- a/tests/features/named_queries.feature +++ /dev/null @@ -1,12 +0,0 @@ -Feature: named queries: - save, use and delete named queries - - Scenario: save, use and delete named queries - When we run dbcli - and we wait for prompt - and we connect to test database - then we see database connected - when we save a named query - then we see the named query saved - when we delete a named query - then we see the named query deleted diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py deleted file mode 100644 index 7aa47664..00000000 --- a/tests/features/steps/basic_commands.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -""" -Steps for behavioral style tests are defined in this module. -Each step is defined by the string decorating it. -This string is used to call the step in "*.feature" file. -""" -from __future__ import unicode_literals - -import pexpect - -from behave import when -import wrappers - - -@when('we run dbcli') -def step_run_cli(context): - """ - Run the process using pexpect. 
- """ - run_args = [] - if context.conf.get('host', None): - run_args.extend(('-h', context.conf['host'])) - if context.conf.get('user', None): - run_args.extend(('-u', context.conf['user'])) - if context.conf.get('pass', None): - run_args.extend(('-p', context.conf['pass'])) - if context.conf.get('dbname', None): - run_args.extend(('-D', context.conf['dbname'])) - cli_cmd = context.conf.get('cli_command', None) or sys.executable+' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"' - - cmd_parts = [cli_cmd] + run_args - cmd = ' '.join(cmd_parts) - context.cli = pexpect.spawnu(cmd, cwd='..') - context.exit_sent = False - - -@when('we wait for prompt') -def step_wait_prompt(context): - """ - Make sure prompt is displayed. - """ - user = context.conf['user'] - host = context.conf['host'] - dbname = context.conf['dbname'] - wrappers.expect_exact(context, 'mysql {0}@{1}:{2}> '.format(user, host, dbname), timeout=5) - - -@when('we send "ctrl + d"') -def step_ctrl_d(context): - """ - Send Ctrl + D to hopefully exit. - """ - context.cli.sendcontrol('d') - context.exit_sent = True - - -@when('we send "\?" command') -def step_send_help(context): - """ - Send \? to see help. - """ - context.cli.sendline('\\?') diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py deleted file mode 100644 index 3eab34d9..00000000 --- a/tests/features/steps/crud_database.py +++ /dev/null @@ -1,104 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Steps for behavioral style tests are defined in this module. -Each step is defined by the string decorating it. -This string is used to call the step in "*.feature" file. -""" -from __future__ import unicode_literals - -import pexpect - -import wrappers -from behave import when, then - - -@when('we create database') -def step_db_create(context): - """ - Send create database. 
- """ - context.cli.sendline('create database {0};'.format( - context.conf['dbname_tmp'])) - - context.response = { - 'database_name': context.conf['dbname_tmp'] - } - - -@when('we drop database') -def step_db_drop(context): - """ - Send drop database. - """ - context.cli.sendline('drop database {0};'.format( - context.conf['dbname_tmp'])) - - wrappers.expect_exact(context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) - context.cli.sendline('y') - -@when('we connect to test database') -def step_db_connect_test(context): - """ - Send connect to database. - """ - db_name = context.conf['dbname'] - context.cli.sendline('use {0}'.format(db_name)) - - -@when('we connect to dbserver') -def step_db_connect_dbserver(context): - """ - Send connect to database. - """ - context.cli.sendline('use mysql') - - -@then('dbcli exits') -def step_wait_exit(context): - """ - Make sure the cli exits. - """ - wrappers.expect_exact(context, pexpect.EOF, timeout=5) - - -@then('we see dbcli prompt') -def step_see_prompt(context): - """ - Wait to see the prompt. - """ - user = context.conf['user'] - host = context.conf['host'] - dbname = context.conf['dbname'] - wrappers.expect_exact(context, 'mysql {0}@{1}:{2}> '.format(user, host, dbname), timeout=5) - - -@then('we see help output') -def step_see_help(context): - for expected_line in context.fixture_data['help_commands.txt']: - wrappers.expect_exact(context, expected_line+'\r\n', timeout=1) - - -@then('we see database created') -def step_see_db_created(context): - """ - Wait to see create database output. - """ - wrappers.expect_exact(context, 'Query OK, 1 row affected\r\n', timeout=2) - - -@then('we see database dropped') -def step_see_db_dropped(context): - """ - Wait to see drop database output. 
- """ - wrappers.expect_exact(context, 'Query OK, 0 rows affected\r\n', timeout=2) - - -@then('we see database connected') -def step_see_db_connected(context): - """ - Wait to see drop database output. - """ - wrappers.expect_exact(context, 'You are now connected to database "', timeout=2) - wrappers.expect_exact(context, '"', timeout=2) - wrappers.expect_exact(context, ' as user "{0}"\r\n'.format(context.conf['user']), timeout=2) diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py deleted file mode 100644 index b73ff0d6..00000000 --- a/tests/features/steps/crud_table.py +++ /dev/null @@ -1,111 +0,0 @@ -# -*- coding: utf-8 -""" -Steps for behavioral style tests are defined in this module. -Each step is defined by the string decorating it. -This string is used to call the step in "*.feature" file. -""" -from __future__ import unicode_literals - -import wrappers -from behave import when, then - - -@when('we create table') -def step_create_table(context): - """ - Send create table. - """ - context.cli.sendline('create table a(x text);') - - -@when('we insert into table') -def step_insert_into_table(context): - """ - Send insert into table. - """ - context.cli.sendline('''insert into a(x) values('xxx');''') - - -@when('we update table') -def step_update_table(context): - """ - Send insert into table. - """ - context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''') - - -@when('we select from table') -def step_select_from_table(context): - """ - Send select from table. - """ - context.cli.sendline('select * from a;') - - -@when('we delete from table') -def step_delete_from_table(context): - """ - Send deete from table. - """ - context.cli.sendline('''delete from a where x = 'yyy';''') - wrappers.expect_exact(context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) - context.cli.sendline('y') - - -@when('we drop table') -def step_drop_table(context): - """ - Send drop table. 
- """ - context.cli.sendline('drop table a;') - wrappers.expect_exact(context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) - context.cli.sendline('y') - - -@then('we see table created') -def step_see_table_created(context): - """ - Wait to see create table output. - """ - wrappers.expect_exact(context, 'Query OK, 0 rows affected\r\n', timeout=2) - - -@then('we see record inserted') -def step_see_record_inserted(context): - """ - Wait to see insert output. - """ - wrappers.expect_exact(context, 'Query OK, 1 row affected\r\n', timeout=2) - - -@then('we see record updated') -def step_see_record_updated(context): - """ - Wait to see update output. - """ - wrappers.expect_exact(context, 'Query OK, 1 row affected\r\n', timeout=2) - - -@then('we see data selected') -def step_see_data_selected(context): - """ - Wait to see select output. - """ - wrappers.expect_exact( - context, '+-----+\r\n| x |\r\n+-----+\r\n| yyy |\r\n+-----+\r\n1 row in set\r\n', timeout=1) - - -@then('we see record deleted') -def step_see_data_deleted(context): - """ - Wait to see delete output. - """ - wrappers.expect_exact(context, 'Query OK, 1 row affected\r\n', timeout=2) - - -@then('we see table dropped') -def step_see_table_dropped(context): - """ - Wait to see drop output. - """ - wrappers.expect_exact(context, 'Query OK, 0 rows affected\r\n', timeout=2) diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py deleted file mode 100644 index 88520046..00000000 --- a/tests/features/steps/iocommands.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -from __future__ import unicode_literals -import os -import wrappers - -from behave import when, then - - -@when('we start external editor providing a file name') -def step_edit_file(context): - """ - Edit file with external editor. 
- """ - context.editor_file_name = 'test_file_{0}.sql'.format(context.conf['vi']) - if os.path.exists(context.editor_file_name): - os.remove(context.editor_file_name) - context.cli.sendline('\e {0}'.format(context.editor_file_name)) - wrappers.expect_exact(context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) - wrappers.expect_exact(context, '\r\n:', timeout=2) - - -@when('we type sql in the editor') -def step_edit_type_sql(context): - context.cli.sendline('i') - context.cli.sendline('select * from abc') - context.cli.sendline('.') - wrappers.expect_exact(context, ':', timeout=2) - - -@when('we exit the editor') -def step_edit_quit(context): - context.cli.sendline('x') - wrappers.expect_exact(context, "written", timeout=2) - - -@then('we see the sql in prompt') -def step_edit_done_sql(context): - for match in 'select * from abc'.split(' '): - wrappers.expect_exact(context, match, timeout=1) - # Cleanup the command line. - context.cli.sendcontrol('u') - # Cleanup the edited file. - if context.editor_file_name and os.path.exists(context.editor_file_name): - os.remove(context.editor_file_name) diff --git a/tests/features/steps/named_queries.py b/tests/features/steps/named_queries.py deleted file mode 100644 index b53ad47d..00000000 --- a/tests/features/steps/named_queries.py +++ /dev/null @@ -1,59 +0,0 @@ -# -*- coding: utf-8 -""" -Steps for behavioral style tests are defined in this module. -Each step is defined by the string decorating it. -This string is used to call the step in "*.feature" file. 
-""" -from __future__ import unicode_literals - -import wrappers -from behave import when, then - - -@when('we save a named query') -def step_save_named_query(context): - """ - Send \ns command - """ - context.cli.sendline('\\fs foo SELECT 12345') - - -@when('we use a named query') -def step_use_named_query(context): - """ - Send \n command - """ - context.cli.sendline('\\f foo') - - -@when('we delete a named query') -def step_delete_named_query(context): - """ - Send \nd command - """ - context.cli.sendline('\\fd foo') - - -@then('we see the named query saved') -def step_see_named_query_saved(context): - """ - Wait to see query saved. - """ - wrappers.expect_exact(context, 'Saved.', timeout=1) - - -@then('we see the named query executed') -def step_see_named_query_executed(context): - """ - Wait to see select output. - """ - wrappers.expect_exact(context, '12345', timeout=1) - wrappers.expect_exact(context, 'SELECT 1', timeout=1) - - -@then('we see the named query deleted') -def step_see_named_query_deleted(context): - """ - Wait to see query deleted. - """ - wrappers.expect_exact(context, 'foo: Deleted', timeout=1) diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py deleted file mode 100644 index 790b2476..00000000 --- a/tests/features/steps/specials.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -""" -Steps for behavioral style tests are defined in this module. -Each step is defined by the string decorating it. -This string is used to call the step in "*.feature" file. -""" -from __future__ import unicode_literals - -import wrappers -from behave import when, then - - -@when('we refresh completions') -def step_refresh_completions(context): - """ - Send refresh command. - """ - context.cli.sendline('rehash') - - -@then('we see completions refresh started') -def step_see_refresh_started(context): - """ - Wait to see refresh output. 
- """ - wrappers.expect_exact(context, 'Auto-completion refresh started in the background', timeout=2) diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py deleted file mode 100644 index eac7c830..00000000 --- a/tests/features/steps/wrappers.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -from __future__ import unicode_literals - -import re - - -def expect_exact(context, expected, timeout): - try: - context.cli.expect_exact(expected, timeout=timeout) - except: - # Strip color codes out of the output. - actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?', '', context.cli.before) - raise Exception('Expected:\n---\n{0!r}\n---\n\nActual:\n---\n{1!r}\n---'.format( - expected, - actual)) diff --git a/tests/test_expanded.py b/tests/test_expanded.py deleted file mode 100644 index 7233e91c..00000000 --- a/tests/test_expanded.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Test the vertical, expanded table formatter.""" -from textwrap import dedent - -from mycli.output_formatter.expanded import expanded_table -from mycli.encodingutils import text_type - - -def test_expanded_table_renders(): - results = [('hello', text_type(123)), ('world', text_type(456))] - - expected = dedent("""\ - ***************************[ 1. row ]*************************** - name | hello - age | 123 - ***************************[ 2. 
row ]*************************** - name | world - age | 456 - """) - assert expected == expanded_table(results, ('name', 'age')) diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 1271f18d..00000000 --- a/tests/test_main.py +++ /dev/null @@ -1,172 +0,0 @@ -import os - -import click -from click.testing import CliRunner - -from mycli.main import (cli, confirm_destructive_query, - is_destructive, query_starts_with, queries_start_with, - thanks_picker, PACKAGE_ROOT) -from utils import USER, HOST, PORT, PASSWORD, dbtest, run - -from textwrap import dedent - -try: - text_type = basestring -except NameError: - text_type = str - -CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT, - '--password', PASSWORD, '_test_db'] - -@dbtest -def test_execute_arg(executor): - run(executor, 'create table test (a text)') - run(executor, 'insert into test values("abc")') - - sql = 'select * from test;' - runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql]) - - assert result.exit_code == 0 - assert 'abc' in result.output - - result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql]) - - assert result.exit_code == 0 - assert 'abc' in result.output - - expected = 'a\nabc\n' - - assert expected in result.output - - -@dbtest -def test_execute_arg_with_table(executor): - run(executor, 'create table test (a text)') - run(executor, 'insert into test values("abc")') - - sql = 'select * from test;' - runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table']) - expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n' - - assert result.exit_code == 0 - assert expected in result.output - - -@dbtest -def test_execute_arg_with_csv(executor): - run(executor, 'create table test (a text)') - run(executor, 'insert into test values("abc")') - - sql = 'select * from test;' - runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv']) - expected = 'a\nabc\n\n' - - assert 
result.exit_code == 0 - assert expected in result.output - - -@dbtest -def test_batch_mode(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') - - sql = ( - 'select count(*) from test;\n' - 'select * from test limit 1;' - ) - - runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS, input=sql) - - assert result.exit_code == 0 - assert 'count(*)\n3\n\na\nabc\n' in result.output - -@dbtest -def test_batch_mode_table(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') - - sql = ( - 'select count(*) from test;\n' - 'select * from test limit 1;' - ) - - runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql) - - expected = (dedent("""\ - +----------+ - | count(*) | - +----------+ - | 3 | - +----------+ - +-----+ - | a | - +-----+ - | abc | - +-----+""")) - - assert result.exit_code == 0 - assert expected in result.output - -@dbtest -def test_batch_mode_csv(executor): - run(executor, '''create table test(a text, b text)''') - run(executor, '''insert into test (a, b) values('abc', 'def'), ('ghi', 'jkl')''') - - sql = 'select * from test;' - - runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql) - - expected = 'a,b\nabc,def\nghi,jkl\n' - - assert result.exit_code == 0 - assert expected in result.output - -def test_query_starts_with(executor): - query = 'USE test;' - assert query_starts_with(query, ('use', )) is True - - query = 'DROP DATABASE test;' - assert query_starts_with(query, ('use', )) is False - -def test_query_starts_with_comment(executor): - query = '# comment\nUSE test;' - assert query_starts_with(query, ('use', )) is True - -def test_queries_start_with(executor): - sql = ( - '# comment\n' - 'show databases;' - 'use foo;' - ) - assert queries_start_with(sql, ('show', 'select')) is True - assert queries_start_with(sql, ('use', 
'drop')) is True - assert queries_start_with(sql, ('delete', 'update')) is False - -def test_is_destructive(executor): - sql = ( - 'use test;\n' - 'show databases;\n' - 'drop database foo;' - ) - assert is_destructive(sql) is True - -def test_confirm_destructive_query_notty(executor): - stdin = click.get_text_stream('stdin') - assert stdin.isatty() is False - - sql = 'drop database foo;' - assert confirm_destructive_query(sql) is None - -def test_thanks_picker_utf8(): - project_root = os.path.dirname(PACKAGE_ROOT) - author_file = os.path.join(project_root, 'AUTHORS') - sponsor_file = os.path.join(project_root, 'SPONSORS') - - name = thanks_picker((author_file, sponsor_file)) - assert isinstance(name, text_type) diff --git a/tests/test_output_formatter.py b/tests/test_output_formatter.py deleted file mode 100644 index 9844c191..00000000 --- a/tests/test_output_formatter.py +++ /dev/null @@ -1,160 +0,0 @@ -# -*- coding: utf-8 -*- -"""Test the generic output formatter interface.""" - -from __future__ import unicode_literals -from decimal import Decimal -from textwrap import dedent - -from mycli.output_formatter.preprocessors import (align_decimals, - bytes_to_string, - convert_to_string, - quote_whitespaces, - override_missing_value, - to_string) -from mycli.output_formatter.output_formatter import OutputFormatter -from mycli.output_formatter import delimited_output_adapter -from mycli.output_formatter import tabulate_adapter -from mycli.output_formatter import terminaltables_adapter - - -def test_to_string(): - """Test the *output_formatter.to_string()* function.""" - assert 'a' == to_string('a') - assert 'a' == to_string(b'a') - assert '1' == to_string(1) - assert '1.23' == to_string(1.23) - - -def test_convert_to_string(): - """Test the *output_formatter.convert_to_string()* function.""" - data = [[1, 'John'], [2, 'Jill']] - headers = [0, 'name'] - expected = ([['1', 'John'], ['2', 'Jill']], ['0', 'name']) - - assert expected == convert_to_string(data, headers) - - 
-def test_override_missing_values(): - """Test the *output_formatter.override_missing_values()* function.""" - data = [[1, None], [2, 'Jill']] - headers = [0, 'name'] - expected = ([[1, ''], [2, 'Jill']], [0, 'name']) - - assert expected == override_missing_value(data, headers, - missing_value='') - - -def test_bytes_to_string(): - """Test the *output_formatter.bytes_to_string()* function.""" - data = [[1, 'John'], [2, b'Jill']] - headers = [0, 'name'] - expected = ([[1, 'John'], [2, 'Jill']], [0, 'name']) - - assert expected == bytes_to_string(data, headers) - - -def test_align_decimals(): - """Test the *align_decimals()* function.""" - data = [[Decimal('200'), Decimal('1')], [ - Decimal('1.00002'), Decimal('1.0')]] - headers = ['num1', 'num2'] - expected = ([['200', '1'], [' 1.00002', '1.0']], ['num1', 'num2']) - - assert expected == align_decimals(data, headers) - - -def test_align_decimals_empty_result(): - """Test *align_decimals()* with no results.""" - data = [] - headers = ['num1', 'num2'] - expected = ([], ['num1', 'num2']) - - assert expected == align_decimals(data, headers) - - -def test_quote_whitespaces(): - """Test the *quote_whitespaces()* function.""" - data = [[" before", "after "], [" both ", "none"]] - headers = ['h1', 'h2'] - expected = ([["' before'", "'after '"], ["' both '", "'none'"]], - ['h1', 'h2']) - - assert expected == quote_whitespaces(data, headers) - - -def test_quote_whitespaces_empty_result(): - """Test the *quote_whitespaces()* function with no results.""" - data = [] - headers = ['h1', 'h2'] - expected = ([], ['h1', 'h2']) - - assert expected == quote_whitespaces(data, headers) - - -def test_tabulate_wrapper(): - """Test the *output_formatter.tabulate_wrapper()* function.""" - data = [['abc', 1], ['d', 456]] - headers = ['letters', 'number'] - output = tabulate_adapter.adapter(data, headers, table_format='psql') - assert output == dedent('''\ - +-----------+----------+ - | letters | number | - |-----------+----------| - | abc | 1 
| - | d | 456 | - +-----------+----------+''') - - -def test_csv_wrapper(): - """Test the *output_formatter.csv_wrapper()* function.""" - # Test comma-delimited output. - data = [['abc', 1], ['d', 456]] - headers = ['letters', 'number'] - output = delimited_output_adapter.adapter(data, headers) - assert output == dedent('''\ - letters,number\r\n\ - abc,1\r\n\ - d,456\r\n''') - - # Test tab-delimited output. - data = [['abc', 1], ['d', 456]] - headers = ['letters', 'number'] - output = delimited_output_adapter.adapter( - data, headers, table_format='tsv') - assert output == dedent('''\ - letters\tnumber\r\n\ - abc\t1\r\n\ - d\t456\r\n''') - - -def test_terminal_tables_wrapper(): - """Test the *output_formatter.terminal_tables_wrapper()* function.""" - data = [['abc', 1], ['d', 456]] - headers = ['letters', 'number'] - output = terminaltables_adapter.adapter( - data, headers, table_format='ascii') - assert output == dedent('''\ - +---------+--------+ - | letters | number | - +---------+--------+ - | abc | 1 | - | d | 456 | - +---------+--------+''') - - -def test_output_formatter(): - """Test the *output_formatter.OutputFormatter* class.""" - data = [['abc', Decimal(1)], ['defg', Decimal('11.1')], - ['hi', Decimal('1.1')]] - headers = ['text', 'numeric'] - expected = dedent('''\ - +------+---------+ - | text | numeric | - +------+---------+ - | abc | 1 | - | defg | 11.1 | - | hi | 1.1 | - +------+---------+''') - - assert expected == OutputFormatter().format_output(data, headers, - format_name='ascii') diff --git a/tests/test_special_iocommands.py b/tests/test_special_iocommands.py deleted file mode 100644 index 2ed2dbad..00000000 --- a/tests/test_special_iocommands.py +++ /dev/null @@ -1,78 +0,0 @@ -# coding: utf-8 -import os -import stat -import tempfile - -import pytest - -import mycli.packages.special -import utils - - -def test_set_get_pager(): - mycli.packages.special.set_pager_enabled(True) - assert mycli.packages.special.is_pager_enabled() - 
mycli.packages.special.set_pager_enabled(False) - assert not mycli.packages.special.is_pager_enabled() - mycli.packages.special.set_pager('less') - assert os.environ['PAGER'] == "less" - mycli.packages.special.set_pager(False) - assert os.environ['PAGER'] == "less" - del os.environ['PAGER'] - mycli.packages.special.set_pager(False) - mycli.packages.special.disable_pager() - assert not mycli.packages.special.is_pager_enabled() - -def test_set_get_timing(): - mycli.packages.special.set_timing_enabled(True) - assert mycli.packages.special.is_timing_enabled() - mycli.packages.special.set_timing_enabled(False) - assert not mycli.packages.special.is_timing_enabled() - -def test_set_get_expanded_output(): - mycli.packages.special.set_expanded_output(True) - assert mycli.packages.special.is_expanded_output() - mycli.packages.special.set_expanded_output(False) - assert not mycli.packages.special.is_expanded_output() - -def test_editor_command(): - assert mycli.packages.special.editor_command(r'hello\e') - assert mycli.packages.special.editor_command(r'\ehello') - assert not mycli.packages.special.editor_command(r'hello') - - assert mycli.packages.special.get_filename(r'\e filename') == "filename" - - os.environ['EDITOR'] = 'true' - mycli.packages.special.open_external_editor(r'select 1') == "select 1" - -def test_tee_command(): - mycli.packages.special.write_tee(u"hello world") # write without file set - with tempfile.NamedTemporaryFile() as f: - mycli.packages.special.execute(None, u"tee "+f.name) - mycli.packages.special.write_tee(u"hello world") - assert f.read() == b"hello world\n" - - mycli.packages.special.execute(None, u"tee -o "+f.name) - mycli.packages.special.write_tee(u"hello world") - f.seek(0) - assert f.read() == b"hello world\n" - - mycli.packages.special.execute(None, u"notee") - mycli.packages.special.write_tee(u"hello world") - f.seek(0) - assert f.read() == b"hello world\n" - -def test_tee_command_error(): - with pytest.raises(TypeError): - 
mycli.packages.special.execute(None, 'tee') - - with pytest.raises(OSError): - with tempfile.NamedTemporaryFile() as f: - os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) - mycli.packages.special.execute(None, 'tee {}'.format(f.name)) - -def test_favorite_query(): - with utils.db_connection().cursor() as cur: - query = u'select "✔"' - mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query)) - assert next(mycli.packages.special.execute(cur, u'\\f check'))[0] == "> " + query diff --git a/tests/test_sqlexecute.py b/tests/test_sqlexecute.py deleted file mode 100644 index a9b5fbec..00000000 --- a/tests/test_sqlexecute.py +++ /dev/null @@ -1,327 +0,0 @@ -# coding=UTF-8 - -import pytest -import pymysql -import os -from textwrap import dedent -from utils import run, dbtest, set_expanded_output - - -@dbtest -def test_conn(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') - results = run(executor, '''select * from test''', join=True) - assert results == dedent("""\ - +-----+ - | a | - +-----+ - | abc | - +-----+ - 1 row in set""") - -@dbtest -def test_bools(executor): - run(executor, '''create table test(a boolean)''') - run(executor, '''insert into test values(True)''') - results = run(executor, '''select * from test''', join=True) - assert results == dedent("""\ - +---+ - | a | - +---+ - | 1 | - +---+ - 1 row in set""") - -@dbtest -def test_binary(executor): - run(executor, '''create table bt(geom linestring NOT NULL)''') - run(executor, '''INSERT INTO bt VALUES (GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));''') - results = run(executor, '''select * from bt''', join=True) - assert results == dedent("""\ - +----------------------------------------------------------------------------------------------+ - | geom | - +----------------------------------------------------------------------------------------------+ - | 
0x00000000010200000002000000397f130a11185d4034f44f70b1de43400000000000185d40423ee8d9acde4340 | - +----------------------------------------------------------------------------------------------+ - 1 row in set""") - -@dbtest -def test_binary_expanded(executor): - run(executor, '''create table bt(geom linestring NOT NULL)''') - run(executor, '''INSERT INTO bt VALUES (GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));''') - results = run(executor, '''select * from bt\G''', join=True) - assert results == dedent("""\ - ***************************[ 1. row ]*************************** - geom | 0x00000000010200000002000000397f130a11185d4034f44f70b1de43400000000000185d40423ee8d9acde4340 - - 1 row in set""") - -@dbtest -def test_table_and_columns_query(executor): - run(executor, "create table a(x text, y text)") - run(executor, "create table b(z text)") - - assert set(executor.tables()) == set([('a',), ('b',)]) - assert set(executor.table_columns()) == set( - [('a', 'x'), ('a', 'y'), ('b', 'z')]) - -@dbtest -def test_database_list(executor): - databases = executor.databases() - assert '_test_db' in databases - -@dbtest -def test_invalid_syntax(executor): - with pytest.raises(pymysql.ProgrammingError) as excinfo: - run(executor, 'invalid syntax!') - assert 'You have an error in your SQL syntax;' in str(excinfo.value) - -@dbtest -def test_invalid_column_name(executor): - with pytest.raises(pymysql.InternalError) as excinfo: - run(executor, 'select invalid command') - assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) - -@dbtest -def test_unicode_support_in_output(executor): - run(executor, "create table unicodechars(t text)") - run(executor, u"insert into unicodechars (t) values ('é')") - - # See issue #24, this raises an exception without proper handling - assert u'é' in run(executor, u"select * from unicodechars", join=True) - -@dbtest -def test_expanded_output(executor): - run(executor, '''create table test(a text)''') - run(executor, 
'''insert into test values('abc')''') - results = run(executor, '''select * from test\G''', join=True) - - expected_results = set([ - dedent("""\ - -[ RECORD 0 ] - a | abc - - 1 row in set"""), - dedent("""\ - ***************************[ 1. row ]*************************** - a | abc - - 1 row in set"""), - ]) - - assert results in expected_results - -@dbtest -def test_multiple_queries_same_line(executor): - result = run(executor, "select 'foo'; select 'bar'") - assert len(result) == 4 # 2 for the results and 2 more for status messages. - assert "foo" in result[0] - assert "bar" in result[2] - -@dbtest -def test_multiple_queries_same_line_syntaxerror(executor): - with pytest.raises(pymysql.ProgrammingError) as excinfo: - run(executor, "select 'foo'; invalid syntax") - assert 'You have an error in your SQL syntax;' in str(excinfo.value) - -@dbtest -def test_favorite_query(executor): - set_expanded_output(False) - run(executor, "create table test(a text)") - run(executor, "insert into test values('abc')") - run(executor, "insert into test values('def')") - - results = run(executor, "\\fs test-a select * from test where a like 'a%'") - assert results == ['Saved.'] - - results = run(executor, "\\f test-a", join=True) - assert results == dedent("""\ - > select * from test where a like 'a%' - +-----+ - | a | - +-----+ - | abc | - +-----+""") - - results = run(executor, "\\fd test-a") - assert results == ['test-a: Deleted'] - -@dbtest -def test_favorite_query_multiple_statement(executor): - set_expanded_output(False) - run(executor, "create table test(a text)") - run(executor, "insert into test values('abc')") - run(executor, "insert into test values('def')") - - results = run(executor, "\\fs test-ad select * from test where a like 'a%'; " - "select * from test where a like 'd%'") - assert results == ['Saved.'] - - results = run(executor, "\\f test-ad", join=True) - assert results == dedent("""\ - > select * from test where a like 'a%' - +-----+ - | a | - +-----+ - | abc 
| - +-----+ - > select * from test where a like 'd%' - +-----+ - | a | - +-----+ - | def | - +-----+""") - - results = run(executor, "\\fd test-ad") - assert results == ['test-ad: Deleted'] - -@dbtest -def test_favorite_query_expanded_output(executor): - set_expanded_output(False) - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') - - results = run(executor, "\\fs test-ae select * from test") - assert results == ['Saved.'] - - results = run(executor, "\\f test-ae \G", join=True) - - expected_results = set([ - dedent("""\ - > select * from test - -[ RECORD 0 ] - a | abc - """), - dedent("""\ - > select * from test - ***************************[ 1. row ]*************************** - a | abc - """), - ]) - set_expanded_output(False) - - assert results in expected_results - - results = run(executor, "\\fd test-ae") - assert results == ['test-ae: Deleted'] - -@dbtest -def test_special_command(executor): - results = run(executor, '\\?') - expected_line = u'\n| help' - assert len(results) == 1 - assert expected_line in results[0] - -@dbtest -def test_cd_command_without_a_folder_name(executor): - results = run(executor, 'system cd') - expected_line = 'No folder name was provided.' 
- assert len(results) == 1 - assert expected_line in results[0] - -@dbtest -def test_system_command_not_found(executor): - results = run(executor, 'system xyz') - assert len(results) == 1 - expected_line = 'OSError:' - assert expected_line in results[0] - -@dbtest -def test_system_command_output(executor): - test_file_path = os.path.join(os.path.abspath('.'), 'tests/test.txt') - results = run(executor, 'system cat {0}'.format(test_file_path)) - assert len(results) == 1 - expected_line = u'mycli rocks!\n' - assert expected_line == results[0] - -@dbtest -def test_cd_command_current_dir(executor): - tests_path = os.path.join(os.path.abspath('.'), 'tests') - results = run(executor, 'system cd {0}'.format(tests_path)) - assert os.getcwd() == tests_path - -@dbtest -def test_unicode_support(executor): - assert u'日本語' in run(executor, u"SELECT '日本語' AS japanese;", join=True) - -@dbtest -def test_favorite_query_multiline_statement(executor): - set_expanded_output(False) - run(executor, "create table test(a text)") - run(executor, "insert into test values('abc')") - run(executor, "insert into test values('def')") - - results = run(executor, "\\fs test-ad select * from test where a like 'a%';\n" - "select * from test where a like 'd%'") - assert results == ['Saved.'] - - results = run(executor, "\\f test-ad", join=True) - assert results == dedent("""\ - > select * from test where a like 'a%' - +-----+ - | a | - +-----+ - | abc | - +-----+ - > select * from test where a like 'd%' - +-----+ - | a | - +-----+ - | def | - +-----+""") - - results = run(executor, "\\fd test-ad") - assert results == ['test-ad: Deleted'] - -@dbtest -def test_timestamp_null(executor): - run(executor, '''create table ts_null(a timestamp)''') - run(executor, '''insert into ts_null values(0)''') - results = run(executor, '''select * from ts_null''', join=True) - assert results == dedent("""\ - +---------------------+ - | a | - +---------------------+ - | 0000-00-00 00:00:00 | - +---------------------+ - 
1 row in set""") - -@dbtest -def test_datetime_null(executor): - run(executor, '''create table dt_null(a datetime)''') - run(executor, '''insert into dt_null values(0)''') - results = run(executor, '''select * from dt_null''', join=True) - assert results == dedent("""\ - +---------------------+ - | a | - +---------------------+ - | 0000-00-00 00:00:00 | - +---------------------+ - 1 row in set""") - -@dbtest -def test_date_null(executor): - run(executor, '''create table date_null(a date)''') - run(executor, '''insert into date_null values(0)''') - results = run(executor, '''select * from date_null''', join=True) - assert results == dedent("""\ - +------------+ - | a | - +------------+ - | 0000-00-00 | - +------------+ - 1 row in set""") - -@dbtest -def test_time_null(executor): - run(executor, '''create table time_null(a time)''') - run(executor, '''insert into time_null values(0)''') - results = run(executor, '''select * from time_null''', join=True) - assert results == dedent("""\ - +----------+ - | a | - +----------+ - | 00:00:00 | - +----------+ - 1 row in set""") diff --git a/tests/test_tabulate.py b/tests/test_tabulate.py deleted file mode 100644 index ae7c25ce..00000000 --- a/tests/test_tabulate.py +++ /dev/null @@ -1,17 +0,0 @@ -from textwrap import dedent - -from mycli.packages import tabulate - -tabulate.PRESERVE_WHITESPACE = True - - -def test_dont_strip_leading_whitespace(): - data = [[' abc']] - headers = ['xyz'] - tbl = tabulate.tabulate(data, headers, tablefmt='psql') - assert tbl == dedent(''' - +---------+ - | xyz | - |---------| - | abc | - +---------+ ''').strip() diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index b29e5e07..00000000 --- a/tests/utils.py +++ /dev/null @@ -1,56 +0,0 @@ -from os import getenv - -import pymysql -import pytest - -from mycli.main import MyCli, special - -PASSWORD = getenv('PYTEST_PASSWORD') -USER = getenv('PYTEST_USER', 'root') -HOST = getenv('PYTEST_HOST', 'localhost') -PORT = 
getenv('PYTEST_PORT', 3306) -CHARSET = getenv('PYTEST_CHARSET', 'utf8') - -def db_connection(dbname=None): - conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, - charset=CHARSET, - local_infile=False) - conn.autocommit = True - return conn - -try: - db_connection() - CAN_CONNECT_TO_DB = True -except: - CAN_CONNECT_TO_DB = False - -dbtest = pytest.mark.skipif( - not CAN_CONNECT_TO_DB, - reason="Need a mysql instance at localhost accessible by user 'root'") - -def create_db(dbname): - with db_connection().cursor() as cur: - try: - cur.execute('''DROP DATABASE IF EXISTS _test_db''') - cur.execute('''CREATE DATABASE _test_db''') - except: - pass - -def run(executor, sql, join=False): - " Return string output for the sql to be run " - result = [] - - # TODO: this needs to go away. `run()` should not test formatted output. - # It should test raw results. - mycli = MyCli() - for title, rows, headers, status in executor.run(sql): - result.extend(mycli.format_output(title, rows, headers, status, - special.is_expanded_output())) - - if join: - result = '\n'.join(result) - return result - -def set_expanded_output(is_expanded): - """ Pass-through for the tests """ - return special.set_expanded_output(is_expanded) diff --git a/tox.ini b/tox.ini index 8d578d8f..612e8b7f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,15 @@ [tox] -envlist = py27, py33, py34, py35, py36 +envlist = py36, py37, py38 + [testenv] deps = pytest mock -commands = py.test --doctest-modules --doctest-ignore-import-errors + pexpect + behave + coverage +commands = python setup.py test +passenv = PYTEST_HOST + PYTEST_USER + PYTEST_PASSWORD + PYTEST_PORT + PYTEST_CHARSET