diff --git a/.gitignore b/.gitignore index 1ea79bb..c998d8c 100644 --- a/.gitignore +++ b/.gitignore @@ -170,5 +170,5 @@ cython_debug/ # Demo file flagged/ *.sh -!exec_filter.sh +!sanitize.sh logs/ diff --git a/README-SC2INST.md b/README-SC2INST.md new file mode 100644 index 0000000..b5a9078 --- /dev/null +++ b/README-SC2INST.md @@ -0,0 +1,328 @@ +# StarCoder2-Instruct: Fully Transparent and Permissive Self-Alignment for Code Generation + +

+    ⭐️ About
+    | 🚀 Quick start
+    | 📚 Data generation
+    | 🧑‍💻 Training
+    | 📊 Evaluation
+    | ⚠️ Limitations
+

+
+![Banner](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/banner.png)
+
+
+## About
+
+We introduce StarCoder2-15B-Instruct-v0.1, the very first entirely self-aligned code Large Language Model (LLM) trained with a fully permissive and transparent pipeline. Our open-source pipeline uses StarCoder2-15B to generate thousands of instruction-response pairs, which are then used to fine-tune StarCoder2-15B itself without any human annotations or distilled data from huge and proprietary LLMs.
+
+- **Model:** [bigcode/starcoder2-15b-instruct-v0.1](https://huggingface.co/bigcode/starcoder2-instruct-15b-v0.1)
+- **Code:** [bigcode-project/starcoder2-self-align](https://github.com/bigcode-project/starcoder2-self-align)
+- **Dataset:** [bigcode/self-oss-instruct-sc2-exec-filter-50k](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k/)
+- **Authors:**
+[Yuxiang Wei](https://yuxiang.cs.illinois.edu),
+[Federico Cassano](https://federico.codes/),
+[Jiawei Liu](https://jw-liu.xyz),
+[Yifeng Ding](https://yifeng-ding.com),
+[Naman Jain](https://naman-ntc.github.io),
+[Harm de Vries](https://www.harmdevries.com),
+[Leandro von Werra](https://twitter.com/lvwerra),
+[Arjun Guha](https://www.khoury.northeastern.edu/home/arjunguha/main/home/),
+[Lingming Zhang](https://lingming.cs.illinois.edu).
+
+![self-alignment pipeline](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/method.png)
+
+## Quick start
+
+Here is an example to get started with StarCoder2-15B-Instruct-v0.1 using the [transformers](https://huggingface.co/docs/transformers/index) library:
+
+```python
+import transformers
+import torch
+
+pipeline = transformers.pipeline(
+    model="bigcode/starcoder2-15b-instruct-v0.1",
+    task="text-generation",
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+)
+
+def respond(instruction: str, response_prefix: str) -> str:
+    messages = [{"role": "user", "content": instruction}]
+    prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False)
+    prompt += response_prefix
+
+    terminators = [
+        pipeline.tokenizer.eos_token_id,
+        pipeline.tokenizer.convert_tokens_to_ids("###"),
+    ]
+
+    result = pipeline(
+        prompt,
+        max_length=256,
+        num_return_sequences=1,
+        do_sample=False,
+        eos_token_id=terminators,
+        pad_token_id=pipeline.tokenizer.eos_token_id,
+        truncation=True,
+    )
+    response = response_prefix + result[0]["generated_text"][len(prompt) :].split("###")[0].rstrip()
+    return response
+
+
+instruction = "Write a quicksort function in Python with type hints and a 'less_than' parameter for custom sorting criteria."
+response_prefix = ""
+
+print(respond(instruction, response_prefix))
+```
+
+## Data generation pipeline
+
+> Run `pip install -e .` first to install the package locally. Check [seed_gathering](seed_gathering/) for details on how we collected the seeds.
+
+By default, we use an in-memory vLLM engine for data generation, but we also provide an option to use vLLM's [OpenAI compatible server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html) instead.
+
+Set `CUDA_VISIBLE_DEVICES=...` to specify the GPU devices to use for the vLLM engine.
+
+To maximize data generation efficiency, we recommend invoking the script multiple times with different `seed_code_start_index` and `max_new_data` values, each with a vLLM engine running on a separate GPU set.
For example, for a 100k seed dataset on a 2-GPU machine, you can have 2 processes each generating 50k samples by setting `CUDA_VISIBLE_DEVICES=0 --seed_code_start_index 0 --max_new_data 50000` and `CUDA_VISIBLE_DEVICES=1 --seed_code_start_index 50000 --max_new_data 50000`. + +
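+
+For instance, here is a minimal sketch of such a two-process split for the snippet-to-concepts step. The backgrounding with `&`/`wait` is only illustrative, and the omitted flags should be filled in exactly as in the command sections that follow:
+
+```shell
+# Illustrative only: run one generation process per GPU and wait for both to finish.
+CUDA_VISIBLE_DEVICES=0 python src/star_align/self_ossinstruct.py \
+    --instruct_mode "S->C" --seed_data_files /path/to/seeds.jsonl \
+    --seed_code_start_index 0 --max_new_data 50000 &  # remaining flags as below
+CUDA_VISIBLE_DEVICES=1 python src/star_align/self_ossinstruct.py \
+    --instruct_mode "S->C" --seed_data_files /path/to/seeds.jsonl \
+    --seed_code_start_index 50000 --max_new_data 50000 &  # remaining flags as below
+wait
+```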
+
+Click to see how to run with vLLM's OpenAI compatible API
+
+To do so, make sure the vLLM server is running and the associated `openai` environment variables are set.
+
+For example, you can start a vLLM server with `docker`:
+
+```shell
+docker run --gpus '"device=0"' \
+    -v $HF_HOME:/root/.cache/huggingface \
+    -p 10000:8000 \
+    --ipc=host \
+    vllm/vllm-openai:v0.3.3 \
+    --model bigcode/starcoder2-15b \
+    --tensor-parallel-size 1 --dtype bfloat16
+```
+
+And then set the environment variables as follows:
+
+```shell
+export OPENAI_API_KEY="EMPTY"
+export OPENAI_BASE_URL="http://localhost:10000/v1/"
+```
+
+You will also need to set `--use_vllm_server True` in the following commands.
+
+ +
+ +Snippet to concepts generation + +```shell +MODEL=bigcode/starcoder2-15b +MAX_NEW_DATA=1000000 +python src/star_align/self_ossinstruct.py \ + --use_vllm_server False \ + --instruct_mode "S->C" \ + --seed_data_files /path/to/seeds.jsonl \ + --max_new_data $MAX_NEW_DATA \ + --tag concept_gen \ + --temperature 0.7 \ + --seed_code_start_index 0 \ + --model $MODEL \ + --num_fewshots 8 \ + --num_batched_requests 2000 \ + --num_sample_per_request 1 +``` + +
+ +
+ +Concepts to instruction generation + +```shell +MODEL=bigcode/starcoder2-15b +MAX_NEW_DATA=1000000 +python src/star_align/self_ossinstruct.py \ + --instruct_mode "C->I" \ + --seed_data_files /path/to/concepts.jsonl \ + --max_new_data $MAX_NEW_DATA \ + --tag instruction_gen \ + --temperature 0.7 \ + --seed_code_start_index 0 \ + --model $MODEL \ + --num_fewshots 8 \ + --num_sample_per_request 1 \ + --num_batched_request 2000 +``` + +
+ +
+ +Instruction to response (with self-validation code) generation + +```shell +MODEL=bigcode/starcoder2-15b +MAX_NEW_DATA=1000000 +python src/star_align/self_ossinstruct.py \ + --instruct_mode "I->R" \ + --seed_data_files path/to/instructions.jsonl \ + --max_new_data $MAX_NEW_DATA \ + --tag response_gen \ + --seed_code_start_index 0 \ + --model $MODEL \ + --num_fewshots 1 \ + --num_batched_request 500 \ + --num_sample_per_request 10 \ + --temperature 0.7 +``` + +
+ +
+
+Execution filter
+
+> **Warning:** Though we implemented reliability guards, it is highly recommended to run the execution inside the sandboxed environment we provide.
+
+To use the Docker container for executing code, you will first need to run `git submodule update --init --recursive` to clone the server, then run:
+
+```shell
+pushd ./src/star_align/code_exec_server
+./pull_and_run.sh
+popd
+python src/star_align/execution_filter.py \
+    --response_paths /path/to/response.jsonl \
+    --result_path /path/to/filtered.jsonl \
+    --max_batched_tasks 10000 \
+    --container_server http://127.0.0.1:8000
+```
+
+The execution filter produces a flattened list of JSONL entries with a `pass` field indicating whether the execution passed. **It also incrementally dumps the results and can load a cached partial data file.** You can recover an interrupted execution with:
+
+```shell
+python src/star_align/execution_filter.py \
+    --response_paths /path/to/response.jsonl* \
+    --cache_paths /path/to/filtered.jsonl* \
+    --result_path /path/to/filtered-1.jsonl \
+    --max_batched_tasks 10000 \
+    --container_server http://127.0.0.1:8000
+```
+
+Note that execution can sometimes slow down significantly due to excessive resource consumption. To alleviate this, you can limit the Docker container's CPU usage (e.g., `docker run --cpuset-cpus="0-31"`). You can also run a cleanup command after each batch:
+
+```shell
+# For example, you can set the command to be `sudo pkill -f '/tmp/codeexec'`
+export CLEANUP_COMMAND="the command to execute after each batch"
+python src/star_align/execution_filter.py...
+```
+
+Also, the connection to the container may be lost during execution. In this case, simply re-run the script and leverage the caching mechanism described above.
+
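+
+If you want to post-process the results yourself, the `pass` field makes this straightforward. Below is a small, hypothetical sketch (not part of the pipeline) that keeps only the passing entries; it assumes each line of `filtered.jsonl` is a flat JSON record with a boolean `pass` field, and the paths are placeholders:
+
+```python
+import json
+
+# Assumption: every JSONL line is a flat record carrying a boolean "pass" field.
+with open("/path/to/filtered.jsonl") as f:
+    records = [json.loads(line) for line in f if line.strip()]
+
+passing = [r for r in records if r.get("pass")]
+print(f"{len(passing)}/{len(records)} samples passed execution")
+
+with open("/path/to/passing.jsonl", "w") as f:
+    for r in passing:
+        f.write(json.dumps(r) + "\n")
+```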
+ +
+ +Data sanitization and selection + +```shell +# Uncomment to do decontamination +# export MBPP_PATH="/path/to/mbpp.jsonl" +# export DS1000_PATH="/path/to/ds1000_data" +# export DECONTAMINATION=1 +./sanitize.sh /path/to/exec-filtered.jsonl /path/to/sanitized.jsonl +``` + +
+ +## Training Details + +> Run `pip install -e .` first to install the package locally. And install [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up the training. + +### Hyperparameters + +- **Optimizer:** Adafactor +- **Learning rate:** 1e-5 +- **Epoch:** 4 +- **Batch size:** 64 +- **Warmup ratio:** 0.05 +- **Scheduler:** Linear +- **Sequence length:** 1280 +- **Dropout**: Not applied + +### Hardware + +1 x NVIDIA A100 80GB. Yes, you just need one A100 to finetune StarCoder2-15B! + +### Script + +The following script finetunes StarCoder2-15B-Instruct-v0.1 from the base StarCoder2-15B model. `/path/to/dataset.jsonl` is the JSONL format of the [50k dataset](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k) we generated. You can dump the dataset to JSONL to fit the training script. + +
+
+Click to see the training script
+
+NOTE: StarCoder2-15B sets its dropout values to 0.1 by default. We did not apply dropout during finetuning and thus set them to 0.0.
+
+```shell
+MODEL_KEY=bigcode/starcoder2-15b
+LR=1e-5
+EPOCH=4
+SEQ_LEN=1280
+WARMUP_RATIO=0.05
+OUTPUT_DIR=/path/to/output_model
+DATASET_FILE=/path/to/50k-dataset.jsonl
+accelerate launch -m star_align.train \
+    --model_key $MODEL_KEY \
+    --model_name_or_path $MODEL_KEY \
+    --use_flash_attention True \
+    --datafile_paths $DATASET_FILE \
+    --output_dir $OUTPUT_DIR \
+    --bf16 True \
+    --num_train_epochs $EPOCH \
+    --max_training_seq_length $SEQ_LEN \
+    --pad_to_max_length False \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 64 \
+    --group_by_length False \
+    --ddp_find_unused_parameters False \
+    --logging_steps 1 \
+    --log_level info \
+    --optim adafactor \
+    --max_grad_norm -1 \
+    --warmup_ratio $WARMUP_RATIO \
+    --learning_rate $LR \
+    --lr_scheduler_type linear \
+    --attention_dropout 0.0 \
+    --residual_dropout 0.0 \
+    --embedding_dropout 0.0
+```
+
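+
+As noted above, the training script expects the dataset as a JSONL file. One possible way to dump the released 50k dataset to JSONL is sketched below; it assumes the `datasets` library is installed and that the data lives in the default `train` split:
+
+```python
+from datasets import load_dataset
+
+# Sketch: export the dataset to the JSONL file consumed by the training script.
+dataset = load_dataset("bigcode/self-oss-instruct-sc2-exec-filter-50k", split="train")
+dataset.to_json("/path/to/50k-dataset.jsonl", lines=True)
+```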
+
+## Evaluation on EvalPlus, LiveCodeBench, and DS-1000
+
+> Check [evaluation](evaluation/) for more details.
+
+![EvalPlus](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/evalplus.png)
+
+![LiveCodeBench and DS-1000](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/lcb-ds1000.png)
+
+## Bias, Risks, and Limitations
+
+StarCoder2-15B-Instruct-v0.1 is primarily finetuned for Python code generation tasks that can be verified through execution, which may lead to certain biases and limitations. For example, the model might not adhere strictly to instructions that dictate the output format. In these situations, it's beneficial to provide a **response prefix** or a **one-shot example** to steer the model’s output. Additionally, the model may have limitations with other programming languages and out-of-domain coding tasks.
+
+The model also inherits the bias, risks, and limitations from its base StarCoder2-15B model. For more information, please refer to the [StarCoder2-15B model card](https://huggingface.co/bigcode/starcoder2-15b).
diff --git a/README.md b/README.md
index aa8d8f0..d107c10 100644
--- a/README.md
+++ b/README.md
@@ -1,275 +1,61 @@
-# StarCoder2-Instruct: Fully Transparent and Permissive Self-Alignment for Code Generation
+# SelfCodeAlign: Self-Alignment for Code Generation

-    ⭐️ About
-    | 🚀 Quick start
-    | 📚 Data generation
-    | 🧑‍💻 Training
-    | 📊 Evaluation
-    | ⚠️ Limitations
+    Paper
+
+
+

-![Banner](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/banner.png) - - +

+    🧐 About
+    | ⭐️ StarCoder2-Instruct
+    | 📝 Citation
+
+
+
+
+

## About -We introduce StarCoder2-15B-Instruct-v0.1, the very first entirely self-aligned code Large Language Model (LLM) trained with a fully permissive and transparent pipeline. Our open-source pipeline uses StarCoder2-15B to generate thousands of instruction-response pairs, which are then used to fine-tune StarCoder-15B itself without any human annotations or distilled data from huge and proprietary LLMs. +**SelfCodeAlign** is the first fully open and transparent pipeline that enhances a code language model without relying on human annotations or distilled data from large, proprietary models. This approach led to the creation of [StarCoder2-Instruct](README-SC2INST.md), a fully transparent, permissively licensed, self-aligned code model that achieves state-of-the-art performance in coding tasks. -- **Model:** [bigcode/starcoder2-15b-instruct-v0.1](https://huggingface.co/bigcode/starcoder2-instruct-15b-v0.1) -- **Code:** [bigcode-project/starcoder2-self-align](https://github.com/bigcode-project/starcoder2-self-align) -- **Dataset:** [bigcode/self-oss-instruct-sc2-exec-filter-50k](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k/) -- **Authors:** +**Authors:** [Yuxiang Wei](https://yuxiang.cs.illinois.edu), [Federico Cassano](https://federico.codes/), [Jiawei Liu](https://jw-liu.xyz), [Yifeng Ding](https://yifeng-ding.com), [Naman Jain](https://naman-ntc.github.io), +[Zachary Mueller](https://muellerzr.github.io), [Harm de Vries](https://www.harmdevries.com), [Leandro von Werra](https://twitter.com/lvwerra), -[Arjun Guha](https://www.khoury.northeastern.edu/home/arjunguha/main/home/), +[Arjun Guha](https://www.khoury.northeastern.edu/home/arjunguha/main/homehttps://www.khoury.northeastern.edu/home/arjunguha/main/home//), [Lingming Zhang](https://lingming.cs.illinois.edu). -![self-alignment pipeline](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/method.png) - -## Quick start - -Here is an example to get started with StarCoder2-15B-Instruct-v0.1 using the [transformers](https://huggingface.co/docs/transformers/index) library: - -```python -import transformers -import torch - -pipeline = transformers.pipeline( - model="bigcode/starcoder2-15b-instruct-v0.1", - task="text-generation", - torch_dtype=torch.bfloat16, - device_map="auto", -) - -def respond(instruction: str, response_prefix: str) -> str: - messages = [{"role": "user", "content": instruction}] - prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False) - prompt += response_prefix - - teminators = [ - pipeline.tokenizer.eos_token_id, - pipeline.tokenizer.convert_tokens_to_ids("###"), - ] - - result = pipeline( - prompt, - max_length=256, - num_return_sequences=1, - do_sample=False, - eos_token_id=teminators, - pad_token_id=pipeline.tokenizer.eos_token_id, - truncation=True, - ) - response = response_prefix + result[0]["generated_text"][len(prompt) :].split("###")[0].rstrip() - return response - - -instruction = "Write a quicksort function in Python with type hints and a 'less_than' parameter for custom sorting criteria." -response_prefix = "" - -print(respond(instruction, response_prefix)) -``` - -## Data generation pipeline - -> Run `pip install -e .` first to install the package locally. Check [seed_gathering](seed_gathering/) for details on how we collected the seeds. - -We used vLLM's [OpenAI compatible server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html) for data generation. 
So, before running the following commands, make sure the vLLM server is running, and the associated `openai` environment variables are set. - -For example, you can start an vLLM server with `docker`: - -```shell -docker run --gpus '"device=0"' \ - -v $HF_HOME:/root/.cache/huggingface \ - -p 10000:8000 \ - --ipc=host \ - vllm/vllm-openai:v0.3.3 \ - --model bigcode/starcoder2-15b \ - --tensor-parallel-size 1 --dtype bfloat16 -``` - -And then set the environment variables as follows: - -```shell -export OPENAI_API_KEY="EMPTY" -export OPENAI_BASE_URL="http://localhost:10000/v1/" -``` - -
- -Snippet to concepts generation - -```shell -python src/star_align/self_ossinstruct.py \ - --instruct_mode "S->C" \ - --seed_data_files /path/to/seeds.jsonl \ - --max_new_data 50000 \ - --tag concept_gen \ - --temperature 0.7 \ - --seed_code_start_index 0 \ - --model bigcode/starcoder2-15b \ - --num_fewshots 8 \ - --num_batched_requests 32 \ - --num_sample_per_request 1 -``` - -
- -
- -Concepts to instruction generation - -```shell -python src/star_align/self_ossinstruct.py \ - --instruct_mode "C->I" \ - --seed_data_files /path/to/concepts.jsonl \ - --max_new_data 50000 \ - --tag instruction_gen \ - --temperature 0.7 \ - --seed_code_start_index 0 \ - --model bigcode/starcoder2-15b \ - --num_fewshots 8 \ - --num_sample_per_request 1 \ - --num_batched_request 32 -``` - -
- -
- -Instruction to response (with self-validation code) generation - -```shell -python src/star_align/self_ossinstruct.py \ - --instruct_mode "I->R" \ - --seed_data_files path/to/instructions.jsonl \ - --max_new_data 50000 \ - --tag response_gen \ - --seed_code_start_index 0 \ - --model bigcode/starcoder2-15b \ - --num_fewshots 1 \ - --num_batched_request 8 \ - --num_sample_per_request 10 \ - --temperature 0.7 -``` - -
- -
- -Execution filter - -> **Warning:** Though we implemented reliability guards, it is highly recommended to run execution in a sandbox environment. The command below doesn't provide sandboxing by default. - -```shell -python src/star_align/execution_filter.py --response_path /path/to/response.jsonl --result_path /path/to/filtered.jsonl -# The current implementation may cause deadlock. -# If you encounter deadlock, manually do `ps -ef | grep execution_filter` and kill the stuck process. -# Note that filtered.jsonl may contain multiple passing samples for the same instruction which needs further selection. -``` - -For using the the Docker container for executing code, you will first need to `git submodule update --init --recursive` to clone the server, then run: - -```shell -pushd ./src/star_align/code_exec_server -./build_and_run.sh -popd -python src/star_align/execution_filter.py --response_path /path/to/response.jsonl --result_path /path/to/filtered.jsonl --container_server http://127.0.0.1:8000 -``` - -
- -
- -Data sanitization and selection - -```shell -RAW=1 python src/star_align/sanitize_data.py /path/to/filtered.jsonl /path/to/sanitized.jsonl -python src/star_align/clean_data.py --data_files /path/to/sanitized.jsonl --output_file /path/to/sanitized.jsonl --diversify_func_names -SMART=1 python src/star_align/sanitize_data.py /path/to/sanitized.jsonl /path/to/sanitized.jsonl -``` - -
- -## Training Details - -> Run `pip install -e .` first to install the package locally. And install [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up the training. - -### Hyperparameters +![self-alignment pipeline](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/SelfCodeAlign.png) -- **Optimizer:** Adafactor -- **Learning rate:** 1e-5 -- **Epoch:** 4 -- **Batch size:** 64 -- **Warmup ratio:** 0.05 -- **Scheduler:** Linear -- **Sequence length:** 1280 -- **Dropout**: Not applied +## StarCoder2-Instruct -### Hardware - -1 x NVIDIA A100 80GB. Yes, you just need one A100 to finetune StarCoder2-15B! +![Banner](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/banner.png) -### Script +StarCoder2-Instruct is created with an [earlier version](https://github.com/bigcode-project/selfcodealign/tree/starcoder2-instruct) of SelfCodeAlign. It is the very first entirely self-aligned code Large Language Model (LLM) trained with a fully permissive and transparent pipeline. Our open-source pipeline uses StarCoder2-15B to generate thousands of instruction-response pairs, which are then used to fine-tune StarCoder-15B itself without any human annotations or distilled data from huge and proprietary LLMs. -The following script finetunes StarCoder2-15B-Instruct-v0.1 from the base StarCoder2-15B model. `/path/to/dataset.jsonl` is the JSONL format of the [50k dataset](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k) we generated. You can dump the dataset to JSONL to fit the training script. +- **Model:** [bigcode/starcoder2-15b-instruct-v0.1](https://huggingface.co/bigcode/starcoder2-instruct-15b-v0.1) +- **Code:** [bigcode-project/starcoder2-self-align](https://github.com/bigcode-project/starcoder2-self-align) +- **Dataset:** [bigcode/self-oss-instruct-sc2-exec-filter-50k](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k/) -
+For more details, check [README-SC2INST.md](README-SC2INST.md). -Click to see the training script +## Citation -```shell -MODEL_KEY=bigcode/starcoder2-15b -LR=1e-5 -EPOCH=4 -SEQ_LEN=1280 -WARMUP_RATIO=0.05 -OUTPUT_DIR=/path/to/output_model -DATASET_FILE=/path/to/50k-dataset.jsonl -accelerate launch -m star_align.train \ - --model_key $MODEL_KEY \ - --model_name_or_path $MODEL_KEY \ - --use_flash_attention True \ - --datafile_paths $DATASET_FILE \ - --output_dir $OUTPUT_DIR \ - --bf16 True \ - --num_train_epochs $EPOCH \ - --max_training_seq_length $SEQ_LEN \ - --pad_to_max_length False \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 64 \ - --group_by_length False \ - --ddp_find_unused_parameters False \ - --logging_steps 1 \ - --log_level info \ - --optim adafactor \ - --max_grad_norm -1 \ - --warmup_ratio $WARMUP_RATIO \ - --learning_rate $LR \ - --lr_scheduler_type linear +```bibtex +@article{wei2024selfcodealign, + title={SelfCodeAlign: Self-Alignment for Code Generation}, + author={Yuxiang Wei and Federico Cassano and Jiawei Liu and Yifeng Ding and Naman Jain and Zachary Mueller and Harm de Vries and Leandro von Werra and Arjun Guha and Lingming Zhang}, + year={2024}, + journal={arXiv preprint arXiv:2410.24198} +} ``` - -
- -## Evaluation on EvalPlus, LiveCodeBench, and DS-1000 - -> Check [evaluation](evaluation/) for more details. - -![EvalPlus](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/evalplus.png) - -![LiveCodeBench and DS-1000](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/lcb-ds1000.png) - -## Bias, Risks, and Limitations - -StarCoder2-15B-Instruct-v0.1 is primarily finetuned for Python code generation tasks that can be verified through execution, which may lead to certain biases and limitations. For example, the model might not adhere strictly to instructions that dictate the output format. In these situations, it's beneficial to provide a **response prefix** or a **one-shot example** to steer the model’s output. Additionally, the model may have limitations with other programming languages and out-of-domain coding tasks. - -The model also inherits the bias, risks, and limitations from its base StarCoder2-15B model. For more information, please refer to the [StarCoder2-15B model card](https://huggingface.co/bigcode/starcoder2-15b). diff --git a/evaluation/README.md b/evaluation/README.md index 2ce2ef4..841342f 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -1,9 +1,29 @@ -# Reproduce the experiments +# Evaluation > [!IMPORTANT] > **General requirements** > > Before you start, make sure you have cloned the repository and you are in the **root directory of the project**. Make sure you installed the required packages with `pip install -e .`. Different package versions may impact the reproducibility of the results. + +## Running EvalPlus with vLLM + +We implemented batched inference in [evaluation/text2code_vllm.py] using [vLLM](https://docs.vllm.ai/en/latest/). This speed up the evaluation significantly: **a greedy decoding run can be finished within 20 seconds**. Here is the command: + +```bash +MODEL=/path/to/your/model +DATASET=humaneval # or mbpp +SAVE_PATH=evalplus-$(basename $MODEL)-$DATASET.jsonl +CUDA_VISIBLE_DEVICES=0 python -m evaluation.text2code_vllm \ + --model_key $MODEL \ + --dataset $DATASET \ + --save_path $SAVE_PATH + +python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH +``` + +## Reproduce StarCoder2-Instruct + +> [!NOTE] > > We obtained the results with the subsequent hardware and environment: > @@ -12,13 +32,13 @@ > > In case you face issues, we provide the raw outputs we generated in the [evalplus_results](evalplus_results) directory. -## Reproduce HumanEval(+) and MBPP(+) +### Reproduce HumanEval(+) and MBPP(+) We pack multiple problems into one batch to speed up the inference. A different batch size may lead to slightly worse/better results due to the floating point round off resulted from the underlying [cuBLAS](https://docs.nvidia.com/cuda/cublas/index.html) optimization. Make sure you set `CUDA_VISIBLE_DEVICES` to the GPU you want to use and `cd`ed to the root directory of the repo. We assume you use device 0 in the following commands. 
-### HumanEval(+) +#### HumanEval(+) ```bash MODEL_KEY=bigcode/starcoder2-15b-instruct-v0.1 @@ -46,7 +66,7 @@ python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH # pass@1: 0.634 ``` -### MBPP(+) +#### MBPP(+) ```bash MODEL_KEY=bigcode/starcoder2-15b-instruct-v0.1 @@ -71,4 +91,4 @@ python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH # pass@1: 0.642 # mbpp+ (base + extra tests) # pass@1: 0.526 -``` \ No newline at end of file +``` diff --git a/evaluation/ds_1000.py b/evaluation/ds_1000.py new file mode 100644 index 0000000..9225313 --- /dev/null +++ b/evaluation/ds_1000.py @@ -0,0 +1,264 @@ +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Literal, cast +from transformers import AutoTokenizer +from ds1000 import DS1000Dataset, DS1000Problem +from tqdm.auto import tqdm +from transformers import HfArgumentParser + +from star_align.llm_wrapper import ( + GenerationConfig, + ModelContext, + create_infilling_prompt, + get_model_context, +) +from star_align.utils import infer_prompt_template + +from vllm import LLM, SamplingParams + +PROMPT = cast(str, None) + + +@dataclass +class Args: + dataset_path: str + model_key: str + model_name_or_path: str + mode: Literal["Insertion", "Completion"] + output_dir: str + + temperature: float = field(default=0.2) + top_p: float = field(default=0.95) + max_length: int = field(default=1024) + n_samples_per_batch: int = field(default=5) + n_batches: int = field(default=8) + + def to_generation_config(self) -> GenerationConfig: + return GenerationConfig( + # Use max_length to control + max_new_tokens=9999999999999, + top_p=self.top_p, + temperature=self.temperature, + max_length=self.max_length, + ) + + +def postprocess(text: str) -> str: + return text.split("```")[0] + + +def create_prompt(args: Args, tokenizer: AutoTokenizer, problem: DS1000Problem) -> str: + prompt = problem["prompt"] + if args.mode == "Insertion": + prompt = preprocess_insertion_prompt(prompt) + assert prompt.count("[insert]") == 1 + prefix, suffix = prompt.split("[insert]") + prompt = create_infilling_prompt( + model_key=args.model_key, + prefix=prefix, + suffix=suffix, + tokenizer=tokenizer, + ) + else: + assert args.mode == "Completion" + instruction, response_prefix = preprocess_completion_prompt(problem["prompt"]) + prompt = PROMPT.format( + instruction=instruction, + response=response_prefix, + ) + return prompt + + +def generate( + args: Args, + # model_context: ModelContext, + engine: LLM, + problem: DS1000Problem, +): + lib: str = problem["lib"] + model_key = args.model_key.replace("/", "-") + problem_id: str = f"q{problem.problem_id}" + path = Path(args.output_dir) / model_key / lib / args.mode / problem_id + finishing_signal = path / "FINISHED" + if finishing_signal.exists(): + print("Skipping:", path) + return + if not path.exists(): + print("Making directory:", path) + path.mkdir(parents=True, exist_ok=True) + # config = args.to_generation_config() + prompt = create_prompt(args, engine.get_tokenizer(), problem) + print("========PROMPT=======") + print(prompt) + print("========PROMPT=======") + + sampling_params = SamplingParams( + n=args.n_batches * args.n_samples_per_batch, + temperature=args.temperature, + max_tokens=args.max_length, + top_k=-1, + top_p=args.top_p, + stop=["```"], + ) + + # for batch_idx in range(args.n_batches): + # print(f"Generating batch {batch_idx} of {args.n_batches}") + # response = model_context.complete( + # config=config, + # prompts=[prompt] * 
args.n_samples_per_batch, + # stop_tokens=["```"] if os.getenv("STOP") is not None else None, + # ) + print(f"Generating {args.n_batches * args.n_samples_per_batch} samples") + results = engine.generate(prompt, sampling_params) + assert len(results) == 1 + print("=======RESPOSE[-1]=======") + # postprocess_fn: Callable[[str], str] = ( + # (lambda x: x) if args.mode == "Insertion" else postprocess + # ) + postprocess_fn = postprocess + print(postprocess_fn(results[0].outputs[-1].text)) + # print("=======RESPOSE[-1]=======") + # print("=======RESPOSE[RAW]=======") + # print(response.decoded_outputs[-1]) + # print("=======RESPOSE[RAW]=======") + # exit() + assert len(results[0].outputs) == args.n_batches * args.n_samples_per_batch + for idx, output in enumerate(results[0].outputs): + sample = output.text + sample = postprocess_fn(sample) + # global_index = batch_idx * args.n_samples_per_batch + idx + global_index = idx + output_file = path / f"{global_index}.py" + output_file.write_text(sample) + finishing_signal.touch() + + +def preprocess_completion_prompt(prompt: str) -> tuple[str, str]: + """Preprocess the DS-1000 prompt (Completion mode) into instruction and response prefix""" + # hit = False + if not "SOLUTION START" in prompt: + answer_index = prompt.rindex("A:") + answer = prompt[answer_index + 2 :].strip() + instruction: str = prompt[:answer_index].strip() + if instruction.startswith("Problem:"): + instruction = instruction[len("Problem:") :].strip() + if "### BEGIN SOLUTION" in prompt: + assert prompt.count("") == 1 + assert prompt.count("") == 0 + lines = answer.splitlines(keepends=True) + return_line, result_line, begin_line = lines[-3:] + assert return_line.strip().startswith("# return") + assert result_line.strip().startswith("# ") + assert begin_line.strip() == "### BEGIN SOLUTION" + response = "".join(lines[:-3]).strip() + hint = begin_line.replace("###", "#").replace("BEGIN SOLUTION", "Solution") + response += f"\n{hint}\n" + else: + assert "BEGIN SOLUTION" in prompt + assert prompt.count("") == 2 + assert prompt.count("") == 1 + first_block_start = prompt.index("") + first_block_end = prompt.index("") + second_block_start = prompt.index("", first_block_start + 1) + assert first_block_end < second_block_start + lines = answer.splitlines(keepends=True) + block_end, instruction_line, begin_line, block_start = lines[-4:] + assert begin_line.strip() == "BEGIN SOLUTION" + assert block_start.strip() == "" + if not block_end.strip() == "": + if lines[-6].strip() == "": + response_prefix = lines[:-6] + starting_lines = lines[-5:-2] + else: + assert instruction_line.strip() == "" + response_prefix = lines[:-3] + starting_lines = lines[-2:-2] + else: + response_prefix = lines[:-4] + starting_lines = lines[-3:-2] + starting_lines = [f"# {line.lstrip()}" for line in starting_lines] + response = "".join([*response_prefix, *starting_lines]).strip() + response += "\n# Solution\n" + else: + # hit = True + assert prompt.count("") == 0 + assert prompt.count("") == 0 + assert prompt.strip().endswith("# SOLUTION START") + code_prefix = prompt[: prompt.rindex("# SOLUTION START")].strip() + instruction = f"""Write a solution to the following problem: +```python +{code_prefix} +```""" + response = f"```python\n{code_prefix}\n# Solution\n" + instruction = instruction.replace("", "```python").replace("", "```") + response = response.replace("", "```python").replace("", "```") + # if hit: + # print("[Instruction]") + # print(instruction) + # print("[Response]") + # print(response) + # breakpoint() + 
return instruction, response + + +def preprocess_insertion_prompt(prompt: str) -> str: + pattern = """ +BEGIN SOLUTION + +[insert] + +END SOLUTION""" + pattern_index = prompt.index(pattern) + # pattern_block = prompt[pattern_index:] + prefix = prompt[:pattern_index] + # hit = False + if pattern + "\n" in prompt: + index = prompt.index("", pattern_index + len(pattern)) + suffix = prompt[index + len("") :] + else: + # hit = True + assert pattern in prompt + suffix = "" + final_prompt = prefix.strip() + "\n[insert]\n" + suffix.strip() + final_prompt = final_prompt.replace("", "```python").replace("", "```") + # if hit: + # print(final_prompt) + # breakpoint() + return final_prompt + + +def main(): + args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) + dataset = DS1000Dataset(args.dataset_path, mode=args.mode) + + global PROMPT + if (inferred := os.getenv("INFER")) is not None: + if inferred == "1": + PROMPT = infer_prompt_template(args.model_name_or_path) + else: + PROMPT = infer_prompt_template(inferred) + + print("Using prompt:") + print(PROMPT) + + all_problems = [ + problem + for problems in dataset.data.values() + for problem in problems + if args.mode == "Completion" or problem["lib"] != "Matplotlib" + ] + engine = LLM( + tokenizer=args.model_key, model=args.model_name_or_path or args.model_key + ) + # model_context = get_model_context( + # model_key=args.model_key, + # model_name_or_path=args.model_name_or_path, + # ) + for problem in tqdm(all_problems): + # generate(args, model_context, problem) + generate(args, engine, problem) + + +if __name__ == "__main__": + main() diff --git a/evaluation/text2code.py b/evaluation/text2code.py index 58dd789..4f98159 100644 --- a/evaluation/text2code.py +++ b/evaluation/text2code.py @@ -1,3 +1,15 @@ +import warnings + +if __name__ == "__main__": + # Deprecate warning + warnings.warn( + "This module is deprecated. Use `evaluation.text2code_vllm` instead.", + DeprecationWarning, + ) + # Press y to continue + if input("Do you want to continue? 
[y/N]: ").lower() != "y": + exit() + import itertools from dataclasses import dataclass from pathlib import Path @@ -10,6 +22,7 @@ from star_align.prompt_template import SC2_INSTRUCT_PROMPT as PROMPT_TEMPLATE from star_align.utils import chunked + class Text2CodeProblem(TypedDict): id: str instruction: str @@ -25,6 +38,7 @@ def get_humaneval_raw_problems() -> list[dict]: problems = get_human_eval_plus() return list(problems.values()) + def map_mbpp_problem(p: dict) -> Text2CodeProblem: id = p["task_id"] prompt = p["prompt"] diff --git a/evaluation/text2code_vllm.py b/evaluation/text2code_vllm.py new file mode 100644 index 0000000..cb9375e --- /dev/null +++ b/evaluation/text2code_vllm.py @@ -0,0 +1,215 @@ +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal, TypedDict, cast +from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl + +from evoeval.data import get_evo_eval +from transformers import HfArgumentParser + +from star_align.utils import infer_prompt_template, is_base_model + +from vllm import LLM, SamplingParams + + +class Text2CodeProblem(TypedDict): + id: str + prompt: str + instruction: str + response_prefix: str + + +# MBPP_INSTRUCTION = """{nl_description} Your code should satisfy the following assertion: +# ```python +# {assertions} +# ``` +# Enclose your solution in ```python and ```""" + + +def get_mbpp_raw_problems() -> list[dict]: + problems = get_mbpp_plus() + return list(problems.values()) + + +def get_humaneval_raw_problems() -> list[dict]: + problems = get_human_eval_plus() + return list(problems.values()) + + +def get_evoeval_raw_problems(dataset: str): + def get_raw_problems() -> list[dict]: + problems = get_evo_eval(dataset) + return list(problems.values()) + + return get_raw_problems + + +def map_mbpp_problem(p: dict) -> Text2CodeProblem: + id = p["task_id"] + prompt = p["prompt"] + start_index = prompt.index('"""') + end_index = prompt.rindex('"""') + prompt = prompt[start_index + 3 : end_index] + assert_index = prompt.index("assert") + instruction = prompt[:assert_index].strip() + if not instruction.endswith("."): + instruction += "." 
+ assertion = prompt[assert_index:].strip() + instruction = f"""{instruction} + +```python +{assertion} +```""" + prefix = "" + response_prefix = f"""{prefix}```python""" + return Text2CodeProblem( + id=str(id), + prompt=prompt, + instruction=instruction, + response_prefix=response_prefix, + ) + + +def map_humaneval_problem(p: dict) -> Text2CodeProblem: + id = p["task_id"] + prompt = p["prompt"] + prompt = prompt.strip() + # try: + # docstring_index = prompt.index('"""') + # except ValueError: + # docstring_index = prompt.index("'''") + # signature = prompt[:docstring_index].strip() + # Instruction + # instruction = f"""Complete the implementation of the following function: + prompt_header = os.getenv( + "PROMPT_HEADER", "Write a Python function to solve the following task:" + ) + instruction = f"""{prompt_header} +```python +{prompt} +```""" + prefix = "" + prefix_template = os.getenv("PREFIX_TEMPLATE", "```python") + response_prefix = prefix + ( + prefix_template.replace("{prompt}", prompt) + if "{prompt}" in prefix_template + else prefix_template + ) + # response_prefix = f"""{prefix}```python + # {prompt}""" + return Text2CodeProblem( + id=id, + prompt=prompt, + instruction=instruction, + response_prefix=response_prefix, + ) + + +@dataclass(frozen=True) +class Args: + model_key: str + dataset: Literal[ + "humaneval", + "mbpp", + "EvoEval_difficult", + "EvoEval_creative", + "EvoEval_subtle", + "EvoEval_combine", + "EvoEval_tool_use", + "EvoEval_verbose", + "EvoEval_concise", + ] + save_path: str + n_samples_per_problem: int = field(default=1) + max_new_tokens: int = field(default=1024) + top_p: float = field(default=1.0) + temperature: float = field(default=0.0) + model_name_or_path: str | None = None + + +def main(): + args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) + raw_problem_fn, map_problem_fn = ( + (get_evoeval_raw_problems(args.dataset), map_humaneval_problem) + if args.dataset.startswith("EvoEval_") + else ( + (get_humaneval_raw_problems, map_humaneval_problem) + if args.dataset == "humaneval" + else (get_mbpp_raw_problems, map_mbpp_problem) + ) + ) + raw_problems = raw_problem_fn() + problems = list(map(map_problem_fn, raw_problems)) + + engine = LLM( + tokenizer=args.model_key, model=args.model_name_or_path or args.model_key + ) + + base_model_prompt = is_base_model(args.model_key) + + stop: str | list[str] = ( + "\n```\n" + if not base_model_prompt + else ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert ", "\n# "] + ) + sampling_params = SamplingParams( + n=args.n_samples_per_problem, + temperature=args.temperature, + max_tokens=args.max_new_tokens, + top_k=-1, + top_p=args.top_p, + stop=stop, + ) + + if base_model_prompt: + print("Base model") + else: + prompt_template = infer_prompt_template( + os.getenv("TOKENIZER") or args.model_name_or_path or args.model_key + ) + # prompt_template = PROMPT_TEMPLATE + print("Using:", prompt_template) + + prompts: list[str] = [] + for problem in problems: + if not base_model_prompt: + prompt = prompt_template.format( + instruction=problem["instruction"], response=problem["response_prefix"] + ) + else: + prompt = problem["prompt"] + prompts.append(prompt) + + results = engine.generate(prompts, sampling_params) + Path(args.save_path).write_text("") + + step = 20 + print_or_not = [idx == 0 or idx % step == 0 for idx in range(len(problems))] + + def sanitize(output: str) -> str: + if not base_model_prompt: + return output.split("```python")[-1].split("```")[0] + for s in stop: + output = output.rsplit(s, 
1)[0] + return output + + for problem, prompt, result, print_debug in zip( + problems, prompts, results, print_or_not + ): + if print_debug: + print("[Example Prompt]") + print(prompt) + print("[Example Completion]") + print(result.outputs[0].text) + samples = [ + dict( + task_id=problem["id"], + completion=sanitize(output.text), + ) + for output in result.outputs + ] + write_jsonl(args.save_path, samples, append=True) + + +if __name__ == "__main__": + main() diff --git a/prompts/self-ossinstruct-fewshot.txt b/prompts/self-ossinstruct-fewshot.txt index a709231..8c73095 100644 --- a/prompts/self-ossinstruct-fewshot.txt +++ b/prompts/self-ossinstruct-fewshot.txt @@ -1,15 +1,21 @@ ### System: I->R -Provide the best response to a given instruction. Follow the following steps to craft your response: -1. reason about the given instruction -2. provide a high-quality solution -3. offer a concise explanation -4. write assertions and **invoke them** to verify the correctness your solution +You are an extremely intelligent AI coding assistant. Please provide an accurate and reliable response to each user instruction. After delivering your response, verify its consistency and correctness by writing a series of executable tests. ### System: C->I -Create a set of independent code instructions that are original, different, diverse, and high-quality, where the properties control an instruction's category, language, concepts, and difficulty. +Create a series of independent coding tasks that are original, distinct, diverse, and high-quality, fostering logical thinking. Each task must adhere to specified properties: + +- category: the type of task (e.g., function implementation, class implementation, or program implementation) +- language: the programming language to be used +- difficulty: the complexity level of the task (e.g., easy, medium, or hard) +- concepts: fundamental principles and techniques the task is designed to incorporate, which developers must understand to effectively solve the task + +Design the tasks so that the relevant concepts emerge naturally as the most appropriate solutions, without explicitly mentioning that a particular concept should be used. ### System: S->C -Extract key programming concepts from a given code snippet collected from the open source repositories. Present the concepts as a comma separated list. +Extract key programming concepts from the provided code snippet. Programming concepts refer to the foundational principles and techniques used in programming, which are crucial for developers to master. List these concepts in a comma-separated format. + +### System: S->I +Gain inspiration from the given code snippets and create a series of independent coding tasks that are original, distinct, diverse, and high-quality, fostering logical thinking. ### Example 1 [Code] @@ -25,38 +31,15 @@ def _split_into_chunks(value): value >>= 5 [Property] -category: code generation (function implementation) +category: function implementation language: Python -concepts: 5-bit chunk encoding with bitwise shifts, ASCII value manipulation, continuation bit signaling difficulty: medium +concepts: 5-bit chunk encoding with bitwise shifts, ASCII value manipulation, continuation bit signaling [Instruction] Create a Python function to encode an integer into a sequence of encoded characters, where each character represents a 5-bit chunk of the integer. The ASCII value of each chunk character is formed by adding 63 to the 5-bit chunk's value. 
For all chunks except the last one, set the 6th bit to 1 to indicate continuation. Employ bitwise operations for this encoding process. -[Solution] -```python -def encode_integer_to_chars(value: int) -> str: - if value == 0: - return chr(63) - encoded_chars = [] - while value > 0: - chunk = (value & 31) + 63 - value >>= 5 - if value > 0: - chunk |= 32 - encoded_chars.append(chr(chunk)) - return ''.join(reversed(encoded_chars)) - -# Tests -assert encode_integer_to_chars(0) == chr(63) -assert encode_integer_to_chars(1) == chr(63 + 1) -assert encode_integer_to_chars(31) == chr(63 + 31) -assert encode_integer_to_chars(int(0b1)) == chr(0b1 + 63) -assert encode_integer_to_chars(int(0b1000001)) == chr(0b10 + 63) + chr(63 + 0b100001) -``` - [Response] -[Reasoning] To encode an integer into a sequence of encoded characters based on the specified criteria, let's follow these steps: 1. Break down the integer into 5-bit chunks: Extract 5-bit chunks from the integer using using `& 31` (`0x1F`), because the binary representation of `31` (`11111`) has the last five bits set, which will effectively extract the 5 bits we are interested in. We start from the least significant bits (LSB). @@ -65,37 +48,28 @@ To encode an integer into a sequence of encoded characters based on the specifie 4. Convert to characters: Use `ord` and `chr` to convert each chunk (now effectively 6 bits due to the continuation bit) into a character by adding 63 to its value. This ensures that the character is within a printable ASCII range. 5. Assembly: Assemble these characters into the final encoded string while maintaining the order of the chunks through `reversed`. -[Implementation] Here's how you can implement this function: ```python def encode_integer_to_chars(value: int) -> str: - # Base case for 0 if value == 0: return chr(63) encoded_chars = [] while value > 0: - # Extract the 5-bit chunk and add 63 to its value chunk = (value & 31) + 63 value >>= 5 - # If there's still more of the value left, set the continuation bit if value > 0: chunk |= 32 encoded_chars.append(chr(chunk)) - # Reverse the list of characters to start with the least significant chunk - # and join them into a single string return ''.join(reversed(encoded_chars)) ``` -[Explanation] -This function processes the integer bit by bit in 5-bit chunks, applying bitwise operations to extract and manipulate these chunks, then converts each chunk into a character using the specified encoding scheme. This approach efficiently handles the encoding in a loop until all bits of the integer have been processed. - -Note that this function is designed for non-negative integers and does not support negative integers by default because negative numbers are represented in two's complement form, where the most significant bit (MSB) indicates the sign. In Python, right-shifting (`>>`) a negative integer preserves the sign bit, effectively making the never executed due to the `value > 0` condition. If you change it to `value != 0`, it will then lead to an infinite loop if the same encoding logic is applied to negative integers without modifications. +Keep in mind that this function is designed for non-negative integers. In Python, right-shifting (`>>`) a negative integer preserves the sign bit, which could lead to infinite loops or other incorrect behaviors if the same encoding approach is applied to negative integers without appropriate adjustments. 
[Tests] -You can validate this function with a few test cases to ensure correctness: +You can validate this function with the following test cases that cover scenarios from zero to a number requiring more than one chunk, ensuring the function operates as expected: ```python assert encode_integer_to_chars(0) == chr(63) @@ -105,8 +79,6 @@ assert encode_integer_to_chars(int(0b1)) == chr(0b1 + 63) assert encode_integer_to_chars(int(0b1000001)) == chr(0b10 + 63) + chr(63 + 0b100001) ``` -These tests cover scenarios from zero to a number requiring more than one chunk, ensuring the function operates as expected. - ### Example 2 [Code] def ceil_shift(n, b): @@ -123,37 +95,15 @@ def ceil_shift(n, b): else: [Property] -category: code generation (function implementation) +category: function implementation language: Python -concepts: bitwise operation for division, ceiling without division or float, input validation with assertions, non-negative integer checks difficulty: medium +concepts: bitwise operation for division, ceiling without division or float, input validation with assertions, non-negative integer checks [Instruction] Develop a Python function `ceil_divide_without_div` that computes the ceiling of `n / 2**b` without using division, floating-point operations, built-in functions like `math.ceil`, or converting it into computing the floor. Instead, use bit manipulation to achieve the result. You write assertions to ensure both `n` and `b` are non-negative integers. -[Solution] -```python -def ceil_divide_without_div(n: int, b: int) -> int: - assert isinstance(n, int) and n >= 0, "n must be a non-negative integer" - assert isinstance(b, int) and b >= 0, "b must be a non-negative integer" - mask = (1 << b) - 1 - divisible = (n & mask) == 0 - result = n >> b - if not divisible: - result += 1 - return result - -# Tests -import math -assert ceil_divide_without_div(10, 3) == math.ceil(10 / 2**3) -assert ceil_divide_without_div(17, 4) == math.ceil(17 / 2**4) -assert ceil_divide_without_div(100, 5) == math.ceil(100 / 2**5) -assert ceil_divide_without_div(0, 3) == math.ceil(0 / 2**3) -assert ceil_divide_without_div(1, 0) == math.ceil(1 / 2**0) -``` - [Response] -[Reasoning] We can leverage the properties of bitwise operations to develop a Python function that computes the ceiling of `n / 2**b` without using division, floating-point operations, or any built-in functions like `math.ceil`, and strictly avoiding converting the problem into a floor computation. In the context, dividing by `2**b` is equivalent to right-shifting `n` by `b` bits in binary representation when `n` is divisible by `2**b`. Otherwise, this approach inherently computes the floor of the division. To adjust for the ceiling without directly calculating it, we can consider two cases: @@ -165,7 +115,6 @@ To determine the divisibility of `n` by `2**b`, we can check if the first `b` bi We'll also include assertions to ensure that `n` and `b` are non-negative integers. 
-[Implementation] Here is the implementation of `ceil_divide_without_div`: ```python @@ -173,24 +122,18 @@ def ceil_divide_without_div(n: int, b: int) -> int: assert isinstance(n, int) and n >= 0, "n must be a non-negative integer" assert isinstance(b, int) and b >= 0, "b must be a non-negative integer" - # Constructing the mask to check divisibility mask = (1 << b) - 1 divisible = (n & mask) == 0 - result = n >> b - - # If n is not divisible by 2**b, increment the result to get the ceiling if not divisible: result += 1 - return result ``` -[Explanation] This function calculates the ceiling of `n / 2**b` purely through bitwise manipulation and integer arithmetic. The assertions ensure the input constraints are met, as specified. [Tests] -To verify the correctness of the function, we can run the following tests: +To verify the correctness of the function, we can compare the results of our function with the expected results obtained using `math.ceil` to ensure the correctness of the ceiling division operation: ```python import math @@ -202,8 +145,6 @@ assert ceil_divide_without_div(0, 3) == math.ceil(0 / 2**3) assert ceil_divide_without_div(1, 0) == math.ceil(1 / 2**0) ``` -These tests compare the results of our function with the expected results obtained using `math.ceil` to ensure the correctness of the ceiling division operation. - ### Example 3 [Code] isprime = n >= 2 and 1 or 0 @@ -221,54 +162,15 @@ def prime(x): lastn = lastn + 1 # Check the next number [Property] -category: code generation (class implementation) +category: class implementation language: Python -concepts: prime number memoization, n-th prime calculation, efficiency in prime checks difficulty: medium +concepts: prime number memoization, n-th prime calculation, efficiency in prime checks [Instruction] Implement a Python class with a method `find_nth_prime_memoized(n)` that returns the nth prime number, where n is a positive integer. Utilize memoization to store already found prime numbers in a list within the class to avoid recalculating primes for multiple method calls. Ensure the method efficiently updates and references this list to find new primes as needed. Additionally, enhance performance by avoiding unnecessary prime checks. -[Solution] -```python -class PrimeFinder: - def __init__(self): - self.primes = [2, 3] - - def is_prime(self, num: int) -> bool: - if num < 2: - return False - for prime in self.primes: - if prime * prime > num: - break - if num % prime == 0: - return False - return True - - def find_nth_prime_memoized(self, n: int) -> int: - if n <= len(self.primes): - return self.primes[n - 1] - - current_num = self.primes[-1] + 2 - while len(self.primes) < n: - if self.is_prime(current_num): - self.primes.append(current_num) - current_num += 2 - - return self.primes[n - 1] - -# Tests -first_few_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] -prime_finder = PrimeFinder() -for i, prime in enumerate(first_few_primes, 1): - assert prime_finder.find_nth_prime_memoized(i) == prime -# Cached result -for i, prime in enumerate(first_few_primes, 1): - assert prime_finder.find_nth_prime_memoized(i) == prime -``` - [Response] -[Reasoning] Implementing the required Python class with the method `find_nth_prime_memoized(n)` involves several key steps: 1. The class should have a constructor `__init__` that initializes the cache list to store the memoized prime numbers. Initially, the list can be seeded with the first two prime numbers, 2 and 3. 
@@ -278,7 +180,6 @@ Implementing the required Python class with the method `find_nth_prime_memoized( - It only checks up to the square root of the number to be tested, implemented by comparing the square of the prime with the number. - It skips even numbers by incrementing by 2 for each prime check. This is sound since the cache is initialized with the first **two** primes, 2 and 3. -[Implementation] Here is how you could implement such a class: ```python @@ -297,42 +198,31 @@ class PrimeFinder: return True def find_nth_prime_memoized(self, n: int) -> int: - # Return the n-th prime directly if already in cache if n <= len(self.primes): return self.primes[n - 1] - # Start checking from the number after the last prime in cache current_num = self.primes[-1] + 2 while len(self.primes) < n: if self.is_prime(current_num): - # Cache the new prime self.primes.append(current_num) - # Skip even numbers current_num += 2 return self.primes[n - 1] ``` -[Explanation] -This class `PrimeFinder` initializes with a cache of the first two primes. The `is_prime` method checks if a number is prime by dividing it with cached primes up to its square root. The `find_nth_prime_memoized` method uses this helper method to find and cache primes up to the requested `n`th one. Subsequent calls leverage cached primes for efficiency. - [Tests] -Here's how you can test this class: +Here's how you can test this class by verifying that the class correctly finds the first few primes and that subsequent calls return the memoized results directly: ```python -# Test that the class correctly finds the first few primes first_few_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] prime_finder = PrimeFinder() for i, prime in enumerate(first_few_primes, 1): assert prime_finder.find_nth_prime_memoized(i) == prime -# Test that the memoized results are returned directly for i, prime in enumerate(first_few_primes, 1): assert prime_finder.find_nth_prime_memoized(i) == prime ``` -These tests verify that the class correctly finds the first few primes and that subsequent calls return the memoized results directly. - ### Example 4 [Code] return (a + 1) * (b + 1) - 1 @@ -352,10 +242,10 @@ def multiset_powerset(multiset): while i >= 0 and not changed: [Property] -category: code generation (function implementation) +category: function implementation language: Python -concepts: special formula implementation, odd and even number handling, function composition difficulty: easy +concepts: special formula implementation, odd and even number handling, function composition [Instruction] Here are two special formulas: @@ -373,31 +263,9 @@ $$ Write a Python function to return $f_2(f_1(a, b))$ for given `a` and `b`. -[Solution] -```python -def f1(a, b): - return (a + 1) * (b + 1) - 1 - -def f2(k): - if k % 2 == 0: - return (k**2) / 2 + 2 * k + 1 - else: - return ((k + 1)**2) / 2 + k + 1 - -def f1_f2_combined(a, b): - return f2(f1(a, b)) - -# Tests -assert f1_f2_combined(2, 3) == f2(f1(2, 3)) -assert f1_f2_combined(4, 5) == f2(f1(4, 5)) -assert f1_f2_combined(6, 7) == f2(f1(6, 7)) -``` - [Response] -[Reasoning] Based on the formulas you provided, we can define two Python functions, `f1(a, b)` and `f2(k)`, respectively, and then combine them to calculate $f2(f1(a, b))$ for given `a` and `b`. 
-[Implementation] Here is the implementation: ```python @@ -405,17 +273,17 @@ def f1(a, b): return (a + 1) * (b + 1) - 1 def f2(k): - if k % 2 == 0: # Check if k is even + if k % 2 == 0: return (k**2) / 2 + 2 * k + 1 - else: # k is odd + else: return ((k + 1)**2) / 2 + k + 1 def f1_f2_combined(a, b): return f2(f1(a, b)) ``` -[Explanation] In this implementation: + - `f1(a, b)` calculates the first formula you've given. - `f2(k)` calculates the second formula, with a conditional branch for whether `k` is even or odd. - `f1_f2_combined(a, b)` uses the output of `f1(a, b)` as the input for `f2(k)` and returns the result. @@ -447,44 +315,19 @@ def tripletsSorting(nums, t): for i in range(len(nums)): [Property] -category: code generation (function implementation) +category: function implementation language: Python -concepts: insertion point for a sorted array, optimized time complexity difficulty: easy +concepts: insertion point for a sorted array, optimized time complexity [Instruction] Design a Python function that takes a sorted array and a target value, and return a valid index where target can be inserted to maintain the array's sorted order. Optimize the function to run in logarithmic time complexity. For example, given `array = [1, 3, 5, 5, 6]` and `target = 5`, the function should return either 2 or 3 because 5 presents at both indices 2 and 3. -[Solution] -```python -from typing import List - -def search_insert_position(nums: List[int], target: int) -> int: - left, right = 0, len(nums) - 1 - - while left <= right: - mid = (left + right) // 2 - if nums[mid] == target: - return mid - elif nums[mid] < target: - left = mid + 1 - else: - right = mid - 1 - return left - -# Tests -assert search_insert_position([1, 3, 5, 5, 6], 5) in [2, 3] -assert search_insert_position([1, 3, 5, 5, 6], 2) == 1 -assert search_insert_position([1, 3, 5, 5, 6], 7) == 5 -``` - [Response] -[Reasoning] To solve this problem efficiently and ensure logarithmic time complexity, we can use a binary search algorithm. Compared with a standard binary search that looks for an exact match, we can modify the algorithm such that when the target is not found, we return the `left` bound, which represents the index where the target can be inserted to maintain the array's sorted order. Since any valid index for insertion is acceptable, we can direclty return the index if there is an exact match. -[Implementation] Here is a Python function that implements this approach: ```python @@ -497,1137 +340,1277 @@ def search_insert_position(nums: List[int], target: int) -> int: as `target` can be inserted at index 2 or 3 to maintain the sorted order. """ left, right = 0, len(nums) - 1 - while left <= right: mid = (left + right) // 2 - - # Directly return the index if there's an exact match if nums[mid] == target: return mid elif nums[mid] < target: left = mid + 1 else: right = mid - 1 - - # At this point, `left` is the correct insertion index return left ``` -[Explanation] This implementation ensures that when `target` is not found, `left` represents the correct insertion index. This is because at the end of the loop, all elements to the left of `left` are smaller than `target`, all elements to the right of `left` are equal to or larger than `target`, and `left > right`. Therefore, `left` is the correct index for insertion to maintain the sorted order. 
[Tests] -To test this function, you can use the example you provided: +You can test the function with scenarios where the target is found at multiple indices and where the target is not present in the array: ```python -# Provided example assert search_insert_position([1, 3, 5, 5, 6], 5) in [2, 3] -# Additional cases assert search_insert_position([1, 3, 5, 5, 6], 2) == 1 assert search_insert_position([1, 3, 5, 5, 6], 7) == 5 +assert search_insert_position([1, 3, 5, 5, 6], 0) == 0 ``` -These tests cover the scenario where the target is found at multiple indices, as well as cases where the target is not present in the array but needs to be inserted at the correct position to maintain the sorted order. - ### Example 6 [Code] -files = ['kitti_all_train.data', - 'kitti_all_train.labels', - 'kitti_all_test.data', - 'kitti_all_test.labels'] - -for file in files: - if file not in os.listdir(data_dir): - zip_path = os.path.join(data_dir, 'kitti_features.zip') - target_path = os.path.dirname(zip_path) - print("Extracting {} to {}...".format(zip_path, target_path)) - with zipfile.ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(target_path) - print("Done.") - break - -X_train = np.loadtxt(os.path.join(data_dir, files[0]), np.float64, skiprows=1) -y_train = np.loadtxt(os.path.join(data_dir, files[1]), np.int32, skiprows=1) -X_test = np.loadtxt(os.path.join(data_dir, files[2]), np.float64, skiprows=1) +def decompress(self): + source = self.compressed + if isinstance(source, (bytes, bytearray)): + return self.decompress_bytes() + pos = 0 + node = self.root + res = bytearray() + + while pos < len(source): + code = int(source[pos]) + child = node.children[code] + if child.is_leaf: + res.append(child) + node = self.root + else: + node = child + pos += 1 + + return bytes(res) [Property] -category: code generation (script implementation) +category: function implementation language: Python -concepts: file I/O, data loading and processing with NumPy, zip file handling difficulty: medium +concepts: decoding compressed data, binary tree traversal, byte array processing [Instruction] -I have a zip archive `kitti_features.zip` with four files (no top-level directory): `kitti_all_train.data`, `kitti_all_train.labels`, `kitti_all_test.data`, and `kitti_all_test.labels`. I want to load these files into numpy arrays for a machine learning project. The `.data` files are floating points and the `.labels` files are integers, both with a header row. +Write a Python function `huffman_decompress` that takes two arguments: `compressed_data`, a byte array of 0/1 sequence representing Huffman compressed data, and `huffman_tree`, a Huffman tree that stores the actual data (each leaf is associated with a character). The function should return the original uncompressed data as `bytes`. Define the `HuffmanNode` class yourself as a general tree structure with a `children` list. -Can you write an initial Python script that extract the zip archive (if not already extracted) and loads the training and testing datasets + labels into appropriate numpy variables. +[Response] +To accomplish the Huffman decompression, we first need to define the `HuffmanNode` class, which, as required, should be a general tree structure with a `children` list and an optional `value` attribute to store the character associated with a leaf node. -[Solution] +Then, we'll implement the `huffman_decompress` function. This function will traverse the Huffman tree according to the bits in `compressed_data`. 
When it encounters a leaf node, the associated character will be appended to the result, and the traversal restarts from the root of the tree. This process continues until the entire `compressed_data` is processed.

-[Response]
-To achieve this, you'll need to use `zipfile` to extract the files from the zip archive and `numpy` to load the data into arrays. Since the `.data` files contain `float64` numbers and the `.labels` files contain `int32` integers, both with a header row, to process them appropriately, you can use `np.loadtxt`, specifying data types as `np.float64` for `.data` files and `np.int32` for `.labels` files, while also including `skiprows=1` to bypass the header. Additionally, given that the zip archive doesn't have a top-level directory, the files will be extracted directly into the target directory.
+You said that `compressed_data` is a byte array of 0/1 sequence, so I assume that each byte, not each bit, represents a step in the Huffman tree. In practice, it would be more efficient to pack these steps into individual bits, but I will follow the byte-based format described here.

-Here's a Python script that extracts the files if they haven't been extracted yet and loads the data into numpy arrays:
+Here's how we can implement this:

```python
-import os
-import numpy as np
-import zipfile
+class HuffmanNode:
+    def __init__(self, value=None):
+        self.value = value
+        self.children = []

-def extract_files(zip_path, file_paths, extract_to='.'):
-    """Extracts files from a zip archive if they don't already exist."""
-    all_files_extracted = all(os.path.exists(file_path) for file_path in file_paths)
-    if not all_files_extracted:
-        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-            zip_ref.extractall(extract_to)
+def huffman_decompress(compressed_data, huffman_tree):
+    result = []
+    node = huffman_tree

-def load_data(file_name, dtype):
-    """Loads data from a file into a numpy array, skipping the header row."""
-    return np.loadtxt(file_name, dtype=dtype, skiprows=1)
+    # Each byte represents a bit in `compressed_data`
+    for bit in compressed_data:
+        if len(node.children) > 0:
+            node = node.children[int(bit)]
+        if len(node.children) == 0:
+            result.append(node.value)
+            node = huffman_tree

-def main():
-    zip_path = 'kitti_features.zip'
-    data_files = {
-        'train_data': 'kitti_all_train.data',
-        'train_labels': 'kitti_all_train.labels',
-        'test_data': 'kitti_all_test.data',
-        'test_labels': 'kitti_all_test.labels'
-    }
-    extract_files(zip_path, data_files.values())

+    return "".join(result).encode()
+```

-    # Load the datasets
-    X_train = load_data(data_files['train_data'], np.float64)
-    y_train = load_data(data_files['train_labels'], np.int32)
-    X_test = load_data(data_files['test_data'], np.float64)
-    y_test = load_data(data_files['test_labels'], np.int32)

+In this implementation, each byte in `compressed_data` represents a bit, guiding the traversal of the Huffman tree. Characters are appended upon reaching leaf nodes and the result is then encoded to `bytes`. 
- # You now have the datasets and labels loaded into numpy arrays - print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}") - print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}") +[Tests] +You can test this function with a sample Huffman tree and compressed data to verify that the decompression works correctly: -if __name__ == "__main__": - main() +```python +root = HuffmanNode() +root.children = [HuffmanNode('A'), HuffmanNode('B')] +root.children[0].children = [HuffmanNode('C'), HuffmanNode('D')] +root.children[1].children = [HuffmanNode('E'), HuffmanNode('F')] +compressed_data = bytearray([0, 1, 0, 0, 1, 1]) +# 01 -> D, 00 -> C, 11 -> F +assert huffman_decompress(compressed_data, root) == b'DCF' ``` -This script defines separate functions for data extraction and loading. In the `main` function, it first checks if the data and label files are already extracted. If not, it extracts them from the specified zip archive. Then, it loads the data into `numpy` arrays, ensuring the desired data types and to skip the header row as you requested. Remember to have `numpy` installed in your environment, and ensure that the zip file is in the correct location relative to the script. - ### Example 7 [Code] -template > -class Stack -{ -public: - Stack(); - - Stack(std::initializer_list); - - const bool empty() const; - - const size_t size() const; - - void push(constT&); - - void pop(); - +def format_size(num): + """http://stackoverflow.com/a/1094933 + """ + for x in ['bytes', 'KB', 'MB', 'GB']: + if num < 1024.0 and num > -1024.0: + return "%3.1f%s" % (num, x) + num /= 1024.0 + return "%3.1f%s" % (num, 'TB') +assert format_size(1024**2 - 1) == '1024.0KB' +assert format_size(1024*512) == '512.0KB' [Property] -category: code generation (class implementation) -language: C++ -concepts: template-based generic stack implementation, STL container usage -difficulty: medium +category: function implementation +language: Python +difficulty: easy +concepts: arithmetic operations for size conversion, rounding numbers, dynamic unit selection, string interpolation [Instruction] -Implement a generic stack class in C++ that allows the user to create a stack for any data type. The class should be template-based, enabling type flexibility. Your implementation should support basic stack operations such as `push`, `pop`, `top`, and checking if the stack is `empty`. Additionally, include a method to retrieve the current size of the stack. For the underlying container, use the Standard Template Library's `std::vector` as the default storage mechanism. This will provide a dynamic array that can efficiently resize itself as needed. - -[Solution] +Create a Python function to format a file size in bytes into a human-readable string representation, using 'bytes', 'KB' (kilobytes), 'MB' (megabytes), 'GB' (gigabytes), or 'TB' (terabytes) units. The output should be rounded to one decimal place and follow the format: "{X.Y}{Unit}", where "Unit" is the appropriate size unit and "X.Y" is the rounded size. For example, `format_size(1024 * 512)` should return `'512.0KB'`, and `format_size(1024**2 - 1)` should return `'1024.0KB'`. [Response] -Sure, I'll provide a template-based `Stack` class in C++ that meets your requirements. The class will use `std::vector` as the underlying container and provide methods for `push`, `pop`, `top`, `empty`, and `size` operations. 
Besides a default constructor, I'll include a constructor that takes an `std::initializer_list` to initialize the stack with a variable number of elements. I'll also employ `stdexcept` to handle exceptions for operations on an empty stack. - -Here is a possible implementation: - -```cpp -#include -#include -#include - -template -class Stack { -private: - std::vector elements; // Vector to store stack elements - -public: - Stack() = default; - Stack(std::initializer_list init) : elements(init) {} - - void push(const T& element) { - elements.push_back(element); - } - - void pop() { - if (empty()) { - throw std::out_of_range("Stack<>::pop(): empty stack"); - } - elements.pop_back(); - } - - // Get the top element of the stack (const version) - const T& top() const { - if (empty()) { - throw std::out_of_range("Stack<>::top(): empty stack"); - } - return elements.back(); - } - - // Get the top element of the stack (non-const version) - T& top() { - if (empty()) { - throw std::out_of_range("Stack<>::top(): empty stack"); - } - return elements.back(); - } +To create such a function, we can use a loop to iterate through a predefined list of units (`['bytes', 'KB', 'MB', 'GB', 'TB']`) and check if the given number of bytes fits into the current unit by comparing it with 1024.0 and -1024.0 (for handling negative sizes). We'll keep dividing the number by 1024.0 until the correct unit is found where the number is less than 1024.0. Then, we'll format the result using the appropriate unit and the rounded size. - bool empty() const { - return elements.empty(); - } +```python +def format_size(size_bytes): + units = ['bytes', 'KB', 'MB', 'GB', 'TB'] - std::size_t size() const { - return elements.size(); - } -}; -``` + # Directly return for 0 bytes to avoid division by zero in loop + if size_bytes == 0: + return "0.0bytes" -Here is a simple example showing how to use this `Stack` class: - -```cpp -#include -#include - -int main() { - // Stack of integers (using default constructor) - Stack intStack; - intStack.push(5); - intStack.push(25); - std::cout << "Top element in intStack: " << intStack.top() << std::endl; - intStack.pop(); - std::cout << "Top element in intStack after pop: " << intStack.top() << std::endl; - - // Stack of strings (using initializer list) - Stack stringStack{"Hello", "World"}; - std::cout << "Top element in stringStack: " << stringStack.top() << std::endl; - stringStack.pop(); - std::cout << "Top element in stringStack after pop: " << stringStack.top() << std::endl; - - return 0; -} + num = abs(size_bytes) + unit = units[0] + for u in units[1:]: + if num < 1024.0: + break + num /= 1024.0 + unit = u + formatted_size = f"{num:.1f}{unit}" + return formatted_size if size_bytes >= 0 else f"-{formatted_size}" ``` -This example demonstrates the usage of the `Stack` class with both integer and string types, different initialization methods, and basic stack operations. Note that attempting to `pop` or `top` from an empty stack will result in an `std::out_of_range` exception. Therefore, in practical applications, you should catch these exceptions or verify if the stack is empty before executing these operations. +This implementation handles both positive and negative file sizes and returns the formatted size as a string, even though negative sizes are not common in real-world scenarios. 
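One boundary case is worth spelling out (an illustrative note, not part of the original example): the loop only advances to the next unit once the value reaches 1024.0, so a size just below a unit boundary stays in the smaller unit and rounds up in the formatted string:

```python
# One byte below 1 MB stays in KB and rounds to 1024.0, while exactly 1 MB moves up a unit
assert format_size(1024 ** 2 - 1) == "1024.0KB"
assert format_size(1024 ** 2) == "1.0MB"
```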
-### Example 8 -[Code] - Ok((p, uids)) - } -} - -impl MyRedisPool { - pub(crate) fn get_posts( - &self, - ids: Vec, - ) -> impl Future, Vec), ResError>> + '_ { - self.get_cache_with_perm_with_uids(ids, crate::handler::cache::POST_U8) - } +[Tests] +We can write tests that cover both positive and negative sizes across different units to validate the correctness of the `format_size` function: - pub(crate) async fn get_posts_old( - &self, - tid: u32, - page: usize, - ) -> Result<(Vec, Vec), ResError> { - let key = format!("topic:{}:posts_time_created", tid); - self.get_cache_with_uids_from_zrange(key.as_str(), page, crate::handler::cache::POST_U8) - .await - } +```python +test_data = [ + (1024 ** 2, "1.0MB"), + (1024 ** 2 - 1, "1024.0KB"), + (1024 ** 3 * 512, "512.0GB"), + (1024 ** 4 * 512, "512.0TB"), +] +for size, expected_output in test_data: + assert format_size(size) == expected_output + assert format_size(-size) == f"-{expected_output}" +``` - pub(crate) async fn get_posts_pop( - &self, - tid: u32, +### Example 8 +[Code] +def unify_stringlist(L: list): + """ Adds asterisks to strings that appear multiple times, so the resulting + list has only unique strings but still the same length, order, and meaning. + For example: + unify_stringlist(['a','a','b','a','c']) -> ['a','a*','b','a**','c'] + """ + assert(all([isinstance(l,str) for l in L])) + return [L[i]+"*"*L[:i].count(L[i]) for i in range(len(L))] +assert unify_stringlist(list("abc")) == list("abc") [Property] -category: code generation (struct implementation) -language: Rust -concepts: asynchronous programming through async/await, Rust error handling patterns, Redis interaction, pagination support -difficulty: very hard +category: function implementation +language: Python +difficulty: easy +concepts: list comprehension, unique string creation, precondition checking for type validation [Instruction] -Build a Rust-based `RedisPool` struct that provides asynchronous methods for caching and retrieving posts by their IDs. Ensure your implementation leverages the `tokio` runtime for asynchronous operations and the `redis` crate for interaction with Redis. Your `RedisPool` should include methods to cache posts, retrieve cached posts by IDs, and support pagination for listing posts associated with a topic. Use appropriate error handling to manage potential Redis operation failures. +Design a Python function to encode a list of strings into a unique list with the same length, order, and meaning. The new list is formed by appending asterisks (*) to duplicate strings. For instance, given ['a', 'a', 'b', 'a', 'c'], the function should return ['a', 'a*', 'b', 'a**', 'c'], while for ['a', 'b', 'c'] it should return ['a', 'b', 'c'] without any changes. Write assertions to ensure the input is a list of strings. -[Solution] +Solve the problem in two lines of code, one for the precondition check and the other for the main logic. [Response] -To implement a `RedisPool` struct that provides asynchronous methods for caching and retrieving posts by their IDs, while leveraging the `tokio` runtime for asynchronous operations and the `redis` crate for interaction with Redis, you'll need to follow a few steps. Below is an example implementation that outlines the structure and functionality you've described. 
- -First, ensure you have the necessary dependencies in your `Cargo.toml` file: - -```toml -[dependencies] -# Please check for the latest versions of these crates -tokio = { version = "1", features = ["full"] } -redis = "0.24.0" -async-trait = "0.1.50" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +You can achieve this functionality by first ensuring the input meets your requirements using an `assert` statement with list comprehension. Then, use another list comprehension to iterate through the input list and append `'*'` to the strings. The number of `'*'` is the number of occurrences of the string before the current index. + +Here's how you can do it in two lines of Python code: + +```python +def encode_strings(input_list): + assert all(isinstance(item, str) for item in input_list), "Input must be a list of strings" + return [s + '*' * input_list[:i].count(s) for i, s in enumerate(input_list)] ``` -Next, you'll need a `Post` struct and `RedisPool` with methods to cache posts, retrieve cached posts by IDs, and support pagination for listing posts associated with a topic. You also define a custom error type for error handling: +Note that although the function is concise, it is not the most efficient solution for large lists, as the `count` method has a time complexity of O(n) for each iteration. -```rust -use serde::{Serialize, Deserialize}; +[Tests] +You can test the function with the provided examples to ensure it correctly encodes the strings as expected: -#[derive(Serialize, Deserialize)] -struct Post { - id: u32, - title: String, - content: String, - topic: String, -} +```python +assert encode_strings(['a', 'a', 'b', 'a', 'c']) == ['a', 'a*', 'b', 'a**', 'c'] +assert encode_strings(['a', 'b', 'c']) == ['a', 'b', 'c'] +``` -#[derive(Debug)] -enum RequestError { - RedisError(redis::RedisError), - SerdeError(serde_json::Error), -} +### Example 9 +[Code] +def _create_folds_list(data, count): + """ + Creates folds from the given data. -impl From for RequestError { - fn from(err: redis::RedisError) -> Self { - RequestError::RedisError(err) - } -} + :param data: the data to fold + :param count: the number of folds to create -impl From for RequestError { - fn from(err: serde_json::Error) -> Self { - RequestError::SerdeError(err) - } -} -``` + :return: a list of folds + """ -After that, define your `RedisPool` struct with a field for the Redis connection pool. For simplicity, you'll directly use `AsyncConnection` here, but in a real-world scenario, you might want to use a connection pool. Implement the required asynchronous methods using the `async_trait` crate for defining async methods in traits: - -```rust -struct RedisPool { - conn: redis::AsyncConnection, -} - -use async_trait::async_trait; -use redis::AsyncCommands; - -impl RedisPool { - pub async fn cache_posts(&self, posts: Vec) -> Result<(), RequestError> { - let mut conn = &self.conn; - for post in posts.iter() { - let key = format!("post:{}", post.id); - let _: () = redis::cmd("HMSET") - .arg(&key) - .arg(serde_json::to_string(&post)?) 
- .query_async(&mut conn) - .await - .map_err(|e| e.into())?; - } - Ok(()) - } + fold_count = len(data) / count + folds = list() - pub async fn get_posts_by_ids(&self, ids: Vec) -> Result, RequestError> { - let mut conn = &self.conn; - let keys: Vec = ids.iter().map(|id| format!("post:{}", id)).collect(); - let posts: Vec = conn.get(keys).await.map_err(|e| e.into())?; - posts.iter().map(|p| serde_json::from_str(p)).collect() - } + for fold_index in range(count): + low = int(fold_index * fold_count) + high = int((fold_index + 1) * fold_count) - pub async fn get_posts_for_topic_paginated(&self, tid: u32, page: usize, page_size: usize) -> Result<(Vec, Vec), RequestError> { - let mut conn = &self.conn; - let start = page * page_size; - let end = (page + 1) * page_size - 1; - let key = format!("topic:{}", tid); - let ids: Vec = conn.zrange(key, start as isize, end as isize).await.map_err(|e| e.into())?; - let posts = self.get_posts_by_ids(ids).await?; - Ok((posts, ids)) - } -} + fold = data[low:high] + folds.append(fold) + + return folds + +[Property] +category: function implementation +language: Python +difficulty: medium +concepts: data partitioning and slicing, learning from input-output examples + +[Instruction] +Construct a Python function `create_folds(data, count)` to create a list of folds from the given data, where each fold is a subsequence of the original data. The length of each fold should be approximately equal with at most a difference of 1. The function should satisfy the following assertions: + +```python +assert create_folds(range(12), 2) == [range(0, 6), range(6, 12)] +assert create_folds(range(8), 2) == [range(0, 4), range(4, 8)] +assert create_folds(range(25), -5) == [] +assert create_folds(range(6), 2) == [range(0, 3), range(3, 6)] ``` -To use `RedisPool`, you'll first need to establish a connection to Redis and then interact with the pool for caching or retrieving posts: +[Response] +To create a Python function that generates a list of folds from given data with each fold having approximately equal length (with at most a difference of 1 between any two folds), you can follow these steps: + +1. Return an empty list immediately if `count` is non-positive since it's not possible to create a valid number of folds. +2. Divide the total length of the data by `count`, rounding down, to get the minimum number of items each fold should have. +3. Calculate the remainder to see how many extra items need to be distributed among the folds. They are distributed one by one to the first `remainder` folds. +4. Use a loop to create each fold, adjusting the starting and ending indices based on the calculations above to ensure each fold has the correct number of items. 
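To make step 3 concrete before the implementation below, consider a small worked case (illustrative, not from the original example): splitting 8 items into 3 folds gives `fold_size = 2` and `remainder = 2`, so the first two folds receive one extra item each and the fold sizes become 3, 3, and 2:

```python
# Expected behavior of the implementation shown below for this worked case
assert create_folds(range(8), 3) == [range(0, 3), range(3, 6), range(6, 8)]
```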
+ +Here's how you can implement this: -```rust -#[tokio::main] -async fn main() -> Result<(), Box> { - let client = redis::Client::open("redis://127.0.0.1/")?; - let conn = client.get_async_connection().await?; - let pool = RedisPool { conn }; +```python +def create_folds(data, count): + # Return an empty list if count is non-positive + if count <= 0: + return [] - // Example usage - let posts = vec![Post { id: 1, title: "Hello".into(), content: "World".into(), topic: "rust".into() }]; - pool.cache_posts(posts).await?; + data_length = len(data) + fold_size = data_length // count + remainder = data_length % count - Ok(()) -} + folds = [] + start_index = 0 + for i in range(count): + # Add 1 element to each of the first `remainder` folds + end_index = start_index + fold_size + (1 if i < remainder else 0) + folds.append(range(start_index, end_index)) + start_index = end_index + return folds ``` -This basic implementation covers the essentials you need: caching posts, retrieving them by ID, and supporting pagination. You might need to adjust for specifics like connection pooling or handling serialization/deserialization more gracefully, but this gives you a good starting point. +This implementation ensures that each fold has an equal length with at most a difference of 1 between any two folds. It handles edge cases, such as a non-positive `count`, by returning an empty list as specified. -### Example 9 -[Code] -def test_fill_latitude_longitude(): - # given +[Tests] +You can test your code with the provided assertions to verify that the function works correctly: - posis = data.load_from_json_file("test/data/test_posis.json", data.POSI) - assert len(posis) == 4 +```python +assert create_folds(range(12), 2) == [range(0, 6), range(6, 12)] +assert create_folds(range(8), 2) == [range(0, 4), range(4, 8)] +assert create_folds(range(25), -5) == [] +assert create_folds(range(6), 2) == [range(0, 3), range(3, 6)] +``` - fps = data.load_from_json_file("test/data/test_fps.json", data.WifiFingerprint) - assert len(fps) == 50 - assert fps[0].latitude is None - assert fps[0].longitude is None +### Example 10 +[Code] +def duration(string): + match = re.match(PATTERN, string) + if not match: + raise ValueError('Invalid duration: {}'.format(string)) - # when - fps_new = mf.fill_latitude_longitude(fps, posis) + suffix = match.group(2) + if suffix not in SUFFIX_MAP: + raise ValueError('Invalid duration suffix: {}'.format(string)) - # then - assert fps_new[0].latitude is not None - assert fps_new[0].longitude is not None + return int(match.group(1)) * SUFFIX_MAP[suffix] [Property] -category: code generation (function implementation) +category: function implementation language: Python -concepts: list and dictionary traversal, key-value pair matching, test case creation difficulty: medium +concepts: regular expression for pattern matching, dictionary lookup for suffix mapping, error handling with exceptions [Instruction] -Write a Python function named `fill_lat_long_for_fingerprints` that enriches or updates a list of WiFi fingerprint objects with latitude and longitude information based on matching position IDs. +Create a Python function that converts a string representation of a time duration into its total number of seconds. +The input string should follow a format that includes a numeric part followed by a time unit suffix without any spaces between. 
+The time unit can be seconds in the form of 's' or 'sec', minutes in the form of 'm' or 'min', hours in the form of 'h' or 'hour', and days in the form of 'd' or 'day' (units are lowercase).
+If the input format is invalid, the function should raise a ValueError.

-Input: Two lists, `fingerprints` and `positions`, where each entry is an object or a dictionary. Each `fingerprint` has `pos_id`, `latitude`, and `longitude` keys, and each position has `pos_id`, `latitude`, and `longitude` keys.
+For example, the function should pass the following assertions:

-Additionally, write a test function to validate this behavior using example data files for fingerprints and positions.
-
-[Solution]
+```python
+assert parse_duration('15s') == 15
+assert parse_duration('2min') == 120
+assert parse_duration('1h') == 3600
+assert parse_duration('1d') == 86400
+try:
+    parse_duration('30')
+except ValueError as e:
+    pass
+```

[Response]
-To create the function, we'll iterate over each fingerprint list and search for a matching `pos_id` in the `positions` list. When a match is found, we'll update the `latitude` and `longitude` values of the fingerprint with the corresponding values from the position.
+The desired function transforms a string representation of a duration into its total number of seconds. The implementation can be split into two parts:

-Here's how you can implement the function:
+Input validation:

-```python
-def fill_lat_long_for_fingerprints(fingerprints, positions):
-    """
-    Enriches or updates a list of WiFi fingerprint objects with latitude and longitude
-    information based on matching position IDs.
-    """
-    pos_id_to_lat_long = {pos['pos_id']: (pos['latitude'], pos['longitude']) for pos in positions}
+1. The input string includes two parts: a number as the prefix and the unit as the suffix.
+2. We can use a regular expression to assert that the prefix is a number (`\d+`) and the suffix is one of s/sec/m/min/h/hour/d/day.
+3. If the input string does not match the regular expression, we raise a ValueError.

-    for fingerprint in fingerprints:
-        if fingerprint['pos_id'] in pos_id_to_lat_long:
-            lat, long = pos_id_to_lat_long[fingerprint['pos_id']]
-            fingerprint['latitude'] = lat
-            fingerprint['longitude'] = long
-```
+Result computation:

-To test the function given the example data files, you can define a test function that loads the data, calls the function, and asserts the expected behavior. Let's assume each data file contains a list of dictionaries. Here's an example of how you might structure the test function:
+1. Based on the regular expression match, we capture the number and the unit parts.
+2. We convert the unit to seconds: 's'/'sec' -> 1 second, 'm'/'min' -> 60 seconds, 'h'/'hour' -> 3600 seconds, 'd'/'day' -> 86400 seconds.
+3. We multiply the seconds of the unit by the prefix number and return the result as an integer. 
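To make the validation step concrete, here is a brief illustration (an aside, using the same pattern that appears in the implementation below):

```python
import re

# The prefix must be digits and the suffix one of the supported units
pattern = re.compile(r"(\d+)(s|sec|m|min|h|hour|d|day)$")
assert pattern.match("2min").groups() == ("2", "min")
assert pattern.match("30") is None        # no suffix, so the function raises ValueError
assert pattern.match("5 mins") is None    # spaces and unknown suffixes are rejected
```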
+ +Here's how you can implement this: ```python -def test_fill_lat_long_for_fingerprints_from_files(fingerprints_file_path, positions_file_path): - # Load data from files - with open(fingerprints_file_path, 'r') as f: - fingerprints = json.load(f) - with open(positions_file_path, 'r') as f: - positions = json.load(f) - - fill_lat_long_for_fingerprints(fingerprints, positions) - - positions_dict = {position['pos_id']: position for position in positions} - for fingerprint in fingerprints: - expected_latitude = positions_dict[fingerprint['pos_id']]['latitude'] - expected_longitude = positions_dict[fingerprint['pos_id']]['longitude'] - assert fingerprint['latitude'] == expected_latitude, \ - f"Latitude for pos_id {fingerprint['pos_id']} should be updated to {expected_latitude}" - assert fingerprint['longitude'] == expected_longitude, \ - f"Longitude for pos_id {fingerprint['pos_id']} should be updated to {expected_longitude}" - - print("All tests passed!") - -# Example usage, if the data files are named 'fingerprints.json' and 'positions.json' -# test_fill_lat_long_for_fingerprints_from_files('fingerprints.json', 'positions.json') +import re + +def parse_duration(duration_str): + pattern = re.compile(r"(\d+)(s|sec|m|min|h|hour|d|day)$") + match = pattern.match(duration_str) + if not match: + raise ValueError(f"Invalid duration: {duration_str}") + value, suffix = match.groups() + value = int(value) + suffix_map = { + 's': 1, 'sec': 1, + 'm': 60, 'min': 60, + 'h': 3600, 'hour': 3600, + 'd': 86400, 'day': 86400 + } + return value * suffix_map[suffix] ``` -The test function works by asserting that the latitude and longitude values in the fingerprints have been updated as expected. You can then run the test function with the appropriate file paths to validate the behavior. - -### Example 10 -[Code] -def imdb(path=Path("data/aclImdb/")): - import pickle +[Tests] +You can test your code on various valid and invalid inputs to ensure that the function behaves as expected: +```python +# Test additional valid inputs with different suffix representations +assert parse_duration('10sec') == 10, "Check for 'sec' suffix" +assert parse_duration('3min') == 180, "Check for 'min' suffix" +assert parse_duration('5h') == 18000, "Check for 'h' suffix" +assert parse_duration('2hour') == 7200, "Check for 'hour' suffix" +assert parse_duration('1day') == 86400, "Check for 'day' suffix" + +# Test edge cases +assert parse_duration('0s') == 0, "Zero duration" +assert parse_duration('0min') == 0, "Zero minutes" + +# Test inputs with leading zeros +assert parse_duration('001s') == 1, "Leading zeros" +assert parse_duration('010min') == 600, "Leading zeros with minutes" + +# Test case sensitivity and invalid suffix handling +def assert_value_error(duration_str): try: - return pickle.load((path / "train-test.p").open("rb")) - except FileNotFoundError: + parse_duration(duration_str) + except ValueError: pass - CLASSES = ["neg", "pos", "unsup"] +assert_value_error('10S') +assert_value_error('2hours') +assert_value_error('5 mins') +assert_value_error('h1') +assert_value_error('100') + +# Test extremely large numbers +assert parse_duration('1000000s') == 1000000, "Very large number of seconds" +``` + +### Example 11 +[Code] +class Restaurant(): + + def __init__(self, name, cuisine_type): + self.name = name.title() + self.cuisine_type = cuisine_type + self.number_served = 0 + + def describe_restaurant(self): + msg = f"{self.name} tiene los mejores {self.cuisine_type}." 
+ print(f"\n{msg}") - def get_texts(path): - texts, labels = [], [] - for idx, label in tqdm(enumerate(CLASSES)): - for fname in tqdm((path / label).glob("*.txt"), leave=False): - texts.append(fname.read_text()) - labels.append(idx) - return texts, np.asarray(labels) + def open_restaurant(self): + msg = f"{self.name} EstΓ‘ Abierta. Β‘Adelante! \n" + print(f"\n{msg}") - trXY = get_texts(path / "train") - teXY = get_texts(path / "test") - data = (trXY, teXY) - pickle.dump(data, (path / "train-test.p").open("wb")) - return data + def set_number_served(self, number_served): + self.number_served = number_served #AquΓ­ establecemos la cantidad de clientes atendidos + + def increment_number_served(self, additional_served): + self.number_served += additional_served #AquΓ­ incrementamos la cantidad de clientes atendidos [Property] -category: code generation (function implementation) +category: class implementation language: Python -concepts: data loading given a specific directory structure, file I/O, dataset manipulation with Numpy, caching with pickle, progress display using tqdm difficulty: hard +concepts: object-oriented programming, string formatting and manipulation, encapsulation, input validation [Instruction] -Develop a Python function `load_imdb_dataset` that loads the IMDb dataset for sentiment analysis from a specified directory. Here is the directory structure: - -``` -data/ - aclImdb/ - train/ - neg/ - *.txt (negative review text files) - pos/ - *.txt (positive review text files) - unsup/ - *.txt (unsupervised/unlabeled review text files) - test/ - neg/ - *.txt (negative review text files) - pos/ - *.txt (positive review text files) - unsup/ - *.txt (unsupervised/unlabeled review text files) - train-test.p (pickled file containing the formatted dataset, may or may not exist) -``` +Develop a Python class `StringEnhancer` that offers advanced tools for managing a string. The class should maintain an internal state that is initially an empty string. The class should support extending the string with additional prefix or suffix, converting the entire string to lowercase or uppercase, and altering specific characters within the string. All the operations are communicated through a central method: `apply_operation(self, operation: str)`. -The function should return a tuple of two elements, each being a tuple of two elements. The first element of each tuple is a list of review texts, and the second element is a numpy array of labels (0 for negative, 1 for positive, and 2 for unsupervised). The function should also implement caching for the entire return value using `pickle` to avoid redundant processing, as shown in the directory structure. +The method `apply_operation` interprets the operations specified in the form of `"[opname]:[arg]"`. Here are the detailed formats for each operation, explained through examples: +- Prefix addition: `apply_operation("prefix:Hello, ")` +- Suffix addition: `apply_operation("suffix: World!")` +- Lowercase conversion: `apply_operation("lower:")` +- Uppercase conversion: `apply_operation("upper:")` +- Character replacement: `apply_operation("replace:a->e")` -Use `tqdm` to display progress bars for loading the dataset. - -[Solution] +The class should handle erroneous inputs gracefully, such as incorrect operation formats or arguments. Additionally, implement the `__str__` method to return a formatted string showing the current value of the string and a list of operations that have been applied to it. 
[Response] -To create the `load_imdb_dataset` function as specified, we'll follow these steps: +The goal is to create a Python class `StringEnhancer` capable of managing a string through various operations, using a central method `apply_operation`. The class also needs to handle errors gracefully and record the operation history. Here are the key steps to implement this class: -1. Before processing the dataset, check if the `train-test.p` pickled file exists in the data directory. If it does, load it using `pickle` and return its content immediately. -2. If the cached file does not exist, iterate over the files in the `train` and `test` directories, loading the review texts and labels. Use the directory names (`neg`, `pos`, `unsup`) to determine the labels. Leverage `pathlib` for file system navigation. -3. Use `tqdm` to display progress bars for loading the dataset. -4. After processing the dataset, save the resulting structure to the `train-test.p` file using `pickle` for future use. -5. Return the loaded dataset. +- Initialize the class with an empty string and an empty list to store the operations. +- Inside the `apply_operation` method: + - Parse the operation string to determine the operation type and argument (if any). + - Gracefully handle any errors in the input format by printing an error message without actually raising an exception. + - Execute the correct string manipulation based on the operation. + - Append the operation to the records list. +- Implement the `__str__` method by concatenating the current string value and the recorded operations. -Here's the implementation: +Here's how you can implement the `StringEnhancer` class: ```python -import numpy as np -import pickle -from tqdm import tqdm -from pathlib import Path +class StringEnhancer: + def __init__(self): + self.string = "" + self.operations = [] + + def apply_operation(self, operation: str): + parts = operation.split(":") + if len(parts) == 1: + print("Invalid operation format. Expected '[opname]:[arg]'") + return + + opname, arg = parts[0], parts[1] + if opname == "prefix": + self.string = arg + self.string + elif opname == "suffix": + self.string += arg + elif opname == "lower": + self.string = self.string.lower() + elif opname == "upper": + self.string = self.string.upper() + elif opname == "replace": + if "->" not in arg: + print("Invalid replace operation format. Expected 'replace:a->b'") + return + old_char, new_char = arg.split("->") + self.string = self.string.replace(old_char, new_char) + else: + print(f"Unknown operation: {opname}") -def load_imdb_dataset(data_dir="data"): - """ - Load the IMDb dataset from the specified directory. 
- """ - data_path = Path(data_dir) - pickled_data_path = data_path / 'train-test.p' - - # Check if the pickled dataset exists and load it - if pickled_data_path.exists(): - with pickled_data_path.open('rb') as file: - return pickle.load(file) - - # Helper function to read reviews and labels - def read_reviews_and_labels(path): - reviews = [] - labels = [] - for label, label_dir in [(0, 'neg'), (1, 'pos'), (2, 'unsup')]: - txt_paths = list((path / label_dir).glob('*.txt')) - for text_file in tqdm(txt_paths, desc=f'Loading {label_dir}'): - reviews.append(text_file.read_text(encoding='utf-8')) - labels.append(label) - return reviews, np.array(labels) - - # Load training and testing data - train_reviews, train_labels = read_reviews_and_labels(data_path / 'aclImdb' / 'train') - test_reviews, test_labels = read_reviews_and_labels(data_path / 'aclImdb' / 'test') - - # Package the dataset - dataset = ((train_reviews, train_labels), (test_reviews, test_labels)) - - # Pickle the dataset for future use - with pickled_data_path.open('wb') as file: - pickle.dump(dataset, file) - - return dataset + self.operations.append(operation) + + def __str__(self): + operation_history = ', '.join(self.operations) + return f"Current String: {self.string}\nOperations: {operation_history}" ``` -This function assumes the directory structure you provided and uses `tqdm` to show progress. Note that it also uses the `encoding='utf-8'` parameter when opening files to ensure compatibility with text files that may contain characters outside of the ASCII range. Before running this function, ensure you have the necessary libraries installed (`tqdm`, `numpy`) and the IMDb dataset properly placed in the specified directory. +[Tests] +To test the class, you should apply different operations and check if the string and operation history are updated correctly: -### Example 11 +```python +enhancer = StringEnhancer() +enhancer.apply_operation("prefix:Hello") +enhancer.apply_operation("suffix:, World!") +enhancer.apply_operation("lower:") +enhancer.apply_operation("replace:o->a") + +expected_operation_history = ["prefix:Hello", "suffix:, World!", "lower:", "replace:o->a"] +expected_operation_history_str = ', '.join(expected_operation_history) +expected_str_value = "hella, warld!" 
+assert enhancer.operations == expected_operation_history +assert str(enhancer) == f"Current String: {expected_str_value}\nOperations: {expected_operation_history_str}" +``` + +### Example 12 [Code] -def decompress(self): - source = self.compressed - if isinstance(source, (bytes, bytearray)): - return self.decompress_bytes() - pos = 0 - node = self.root - res = bytearray() +while cur_num < len(sorted_importances): + cluster_idx = int(sorted_importances[cur_num][1]) + filter_idx = int(sorted_importances[cur_num][2]) - while pos < len(source): - code = int(source[pos]) - child = node.children[code] - if child.is_leaf: - res.append(child) - node = self.root - else: - node = child - pos += 1 + if tmp_pruning_quotas[cluster_idx] > 0: + tmp_pruning_quotas[cluster_idx] -= 1 + else: + cur_num += 1 + continue - return bytes(res) + cluster = self.pruned_module_groups_info.get_cluster_by_id(cluster_idx) + for node in cluster.elements: [Property] -category: code generation (class and function implementation) +category: function implementation language: Python -concepts: decoding Huffman compressed data, binary tree traversal, byte array processing -difficulty: medium +difficulty: easy +concepts: iteration with a while loop, index-based data access, conditional branching, nested loops [Instruction] -Write a Python function `huffman_decompress` that takes two arguments: `compressed_data`, a byte array of 0/1 sequence representing Huffman compressed data, and `huffman_tree`, a Huffman tree that stores the actual data (each leaf is associated with a character). The function should return the original uncompressed data as `bytes`. Define the `HuffmanNode` class yourself as a general tree structure with a `children` list. +Create a Python function that identifies all pairs of elements within a list of integers where the second element of the pair is at least greater than the first by `growth_ratio` and the indices of the elements are within a specified `maximum_distance` from each other. + +[Response] +You can write a Python function `find_pairs` that takes three parameters: + +- `numbers`: The list of integers. +- `growth_ratio`: A non-negative `float` that specifies the ratio by which the second element should be greater than the first. +- `maximum_distance`: A non-negative integer that indicates the maximum index distance allowed between the two elements in the pair. + +The function will return a list of tuples where each tuple contains the indices of the pair of elements that satisfy the conditions. + +Here is how you can implement this function: -[Solution] ```python -class HuffmanNode: - def __init__(self, value=None): - self.value = value - self.children = [] +from typing import List, Tuple -def huffman_decompress(compressed_data, huffman_tree): - result = [] - node = huffman_tree - for bit in compressed_data: - if len(node.children) > 0: - node = node.children[int(bit)] - if len(node.children) == 0: - result.append(node.value) - node = huffman_tree +def find_pairs(numbers: List[int], growth_ratio: float, maximum_distance: int) -> List[Tuple[int, int]]: + """ + Find pairs of elements within a list of integers where the second element is at least `growth_ratio` greater than the first and the indices are within `maximum_distance` from each other. 
+ """ + pairs: List[int] = [] + for i in range(len(numbers)): + # For each number, look ahead up to `maximum_distance` elements + for j in range(i + 1, min(i + 1 + maximum_distance, len(numbers))): + if numbers[j] - numbers[i] >= growth_ratio: + pairs.append((i, j)) + return pairs +``` - return "".join(result).encode() +This function iterates over the pairs of elements in the list whose indices satisfy the distance constraint and stores the valid pairs that meet the growth ratio condition. -# Tests -root = HuffmanNode() -root.children = [HuffmanNode('A'), HuffmanNode('B')] -root.children[0].children = [HuffmanNode('C'), HuffmanNode('D')] -root.children[1].children = [HuffmanNode('E'), HuffmanNode('F')] -compressed_data = bytearray([0, 1, 0, 0, 1, 1]) -# 01 -> D, 00 -> C, 11 -> F -assert huffman_decompress(compressed_data, root) == b'DCF' +[Tests] +To ensure the function's correct behavior, you can compare your implementation against a brute-force approach that leverages list comprehension to first generate all possible pairs and then filter them based on the conditions. The tests should cover both the basic functionality where pairs are found, and edge cases where no pairs are found, while also altering the `growth_ratio` and `maximum_distance` parameters to cover more cases: + +```python +def brute_force_pairs(numbers: List[int], growth_ratio: float, maximum_distance: int) -> List[Tuple[int, int]]: + return [ + (i, j) + for i in range(len(numbers)) + for j in range(i + 1, min(i + 1 + maximum_distance, len(numbers))) + if numbers[j] - numbers[i] >= growth_ratio + ] + +test_cases = [ + ([30, 15, 60, 7, 77, 13, 10], 8, 3), + ([30, 15, 60, 7, 77, 13, 10], 0.2, 3), + ([30, 15, 60, 7, 77, 13, 10], 30, 0), + ([30, 15, 60, 7, 77, 13, 10], 0, 2), + ([], 1, 3), + ([8, 8, 8, 8], 1, 0) +] + +for numbers, growth_ratio, maximum_distance in test_cases: + expected_pairs = brute_force_pairs(numbers, growth_ratio, maximum_distance) + assert find_pairs(numbers, growth_ratio, maximum_distance) == expected_pairs ``` -[Response] -[Reasoning] -To accomplish the Huffman decompression, we first need to define the `HuffmanNode` class, which, as required, should be a general tree structure with a `children` list and an optional `value` attribute to store the character associated with a leaf node. +### Example 13 +[Code] +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) -Then, we'll implement the `huffman_decompress` function. This function will traverse the Huffman tree according to the bits in `compressed_data`. When it encounters a leaf node, the associated character will be appended to the result, and the traversal restarts from the root of the tree. This process continues until the entire `compressed_data` is processed. + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) -You said that `compressed_data` is a byte array of 0/1 sequence, so I assume that each byte, not each bit, represents a step in the Huffman tree. 
In practice, it is more efficient + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res -[Implementation] -Here's how we can implement this: +[Property] +category: function implementation +language: Python +difficulty: medium +concepts: top-k accuracy computation, tensor operations in PyTorch, batch processing, default arguments -```python -class HuffmanNode: - def __init__(self, value=None): - self.value = value - self.children = [] +[Instruction] +Create a Python function `calculate_top_k_accuracy` using PyTorch that calculates the top-k accuracy from a batch of predictions and actual labels. It should accept two tensors, `predictions` and `targets`, and an integer `k` (default 3). The function should return the average top-k accuracy as a float, determining if the true label is within the top `k` predictions for each input. -def huffman_decompress(compressed_data, huffman_tree): - # Store the decompressed characters here - result = [] +[Response] +The function will take the following parameters as input: - # Start traversal from the root of the Huffman tree - node = huffman_tree +- `predictions`: a tensor of shape `(batch_size, num_classes)` that contains the predicted scores for each class. +- `targets`: a tensor of shape `(batch_size,)` that contains the true class labels for each input. +- `k`: an integer representing the number of top predictions to consider for computing the accuracy. The default value should be 3. - # Each byte represents a 0/1 bit in `compressed_data` - for bit in compressed_data: - # Check if the node has children - if len(node.children) > 0: - node = node.children[int(bit)] - # After moving to the child, check if it's a leaf node - if len(node.children) == 0: - result.append(node.value) - # Reset to start from the root for the next character - node = huffman_tree +Here is how you can write this function using PyTorch: - # Convert the result to bytes - return "".join(result).encode() +```python +import torch + +def calculate_top_k_accuracy(predictions: torch.Tensor, targets: torch.Tensor, k: int = 3) -> float: + # Ensure this function works even if k is greater than the number of classes + k = min(k, predictions.size(1)) + _, top_k_indices = predictions.topk(k, dim=1) + target_expanded = targets.unsqueeze(1).expand_as(top_k_indices) + correct_predictions = top_k_indices.eq(target_expanded) + correct = correct_predictions.any(dim=1) + accuracy = correct.float().mean() + return accuracy.item() ``` -[Explanation] -In this implementation, we assume that `compressed_data` is a byte array where each byte represents a 0/1 bit in the Huffman tree. We traverse the tree according to the bits in `compressed_data`, and when we reach a leaf node, we append the associated character to the result. Finally, we convert the result to `bytes` by concatenating the characters and encoding them. +The function retrieves the top k predictions' indices, aligns shapes for comparison with true labels, computes accuracy by checking matches, and returns the average accuracy as a float. 
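To illustrate the shape alignment described above, here is a small walk-through on a toy batch (an illustrative aside with assumed values, not part of the original tests):

```python
import torch

preds = torch.tensor([[0.1, 0.7, 0.2],
                      [0.8, 0.15, 0.05]])
targets = torch.tensor([1, 2])
_, top_k = preds.topk(2, dim=1)                    # shape (2, 2): indices of the 2 best classes
expanded = targets.unsqueeze(1).expand_as(top_k)   # shape (2, 2): targets broadcast alongside
hits = top_k.eq(expanded).any(dim=1)               # tensor([True, False])
assert hits.float().mean().item() == 0.5
```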
[Tests] -You can test this function with a sample Huffman tree and compressed data to verify that the decompression works correctly: - ```python -root = HuffmanNode() -root.children = [HuffmanNode('A'), HuffmanNode('B')] -root.children[0].children = [HuffmanNode('C'), HuffmanNode('D')] -root.children[1].children = [HuffmanNode('E'), HuffmanNode('F')] -compressed_data = bytearray([0, 1, 0, 0, 1, 1]) -# 01 -> D, 00 -> C, 11 -> F -assert huffman_decompress(compressed_data, root) == b'DCF' +# Test 1: Basic functionality test +k = 1 +labels = torch.tensor([2, 0, 1, 1]) +preds = torch.tensor([ + [0.1, 0.2, 0.9], # yes + [1.0, 0.2, 0.3], # yes + [0.3, 1.1, 0.2], # yes + [0.6, 0.1, 0.3], # no +]) +accuracy = calculate_top_k_accuracy(preds, labels, k) +assert accuracy == 0.75 + +# Test 2: Edge case with k = number of classes +k = 3 +preds = torch.tensor([ + [0.5, -1.2, 0.3], # yes + [-1.0, 0.1, 1.2], # yes +]) +labels = torch.tensor([0, 2]) +accuracy = calculate_top_k_accuracy(preds, labels, k) +assert accuracy == 1.0 + +# Test 3: k greater than the number of classes should still work +preds = torch.tensor([ + [0.1, -0.2], # yes + [-0.3, 0.4], # yes +]) +labels = torch.tensor([1, 0]) +accuracy = calculate_top_k_accuracy(preds, labels, k=5) +assert accuracy == 1.0 ``` -### Example 12 +### Example 14 [Code] -def format_size(num): - """http://stackoverflow.com/a/1094933 - """ - for x in ['bytes', 'KB', 'MB', 'GB']: - if num < 1024.0 and num > -1024.0: - return "%3.1f%s" % (num, x) - num /= 1024.0 - return "%3.1f%s" % (num, 'TB') -assert format_size(1024**2 - 1) == '1024.0KB' -assert format_size(1024*512) == '512.0KB' +def alternatingCharacters(s): + i = 0 + j = 1 + count = 0 + while j= 0 else f"-{formatted_size}" - -# Tests -test_data = [ - (1024 ** 2, "1.0MB"), - (1024 ** 2 - 1, "1024.0KB"), - (1024 ** 3 * 512, "512.0GB"), - (1024 ** 4 * 512, "512.0TB"), -] -for size, expected_output in test_data: - assert format_size(size) == expected_output - assert format_size(-size) == f"-{expected_output}" -``` +The function should return the total count of matching characters. This implementation requires managing two pointers for the characters in `str1` and `str2`, respectively, and using a conditional loop to handle different stopping criteria based on the `strict` parameter. [Response] -[Reasoning] -To create such a function, we can use a loop to iterate through a predefined list of units (`['bytes', 'KB', 'MB', 'GB', 'TB']`) and check if the given number of bytes fits into the current unit by comparing it with 1024.0 and -1024.0 (for handling negative sizes). We'll keep dividing the number by 1024.0 until the correct unit is found where the number is less than 1024.0. Then, we'll format the result using the appropriate unit and the rounded size. +We can use two pointers, one for each string, to iterate over the characters from the end of str1 and the beginning of str2. The loop continues either until the characters do not match (for both strict and non-strict cases) or until it exhausts one of the strings. The `strict` flag controls the behavior when a mismatch is encountered. 
-[Implementation] ```python -def format_size(size_bytes): - units = ['bytes', 'KB', 'MB', 'GB', 'TB'] - if size_bytes == 0: - return "0.0bytes" # Directly return for 0 bytes to avoid division by zero in loop - num = abs(size_bytes) - unit = units[0] +def count_conditional_matches(str1, str2, strict): + # Initialize pointers for the end of str1 and the start of str2 + i, j = len(str1) - 1, 0 + count = 0 - for u in units[1:]: - if num < 1024.0: + # Loop as long as i is non-negative and j is less than the length of str2 + while i >= 0 and j < len(str2): + if str1[i] == str2[j]: + count += 1 + i -= 1 + j += 1 + elif strict: break - num /= 1024.0 - unit = u - - formatted_size = f"{num:.1f}{unit}" - return formatted_size if size_bytes >= 0 else f"-{formatted_size}" + else: + i -= 1 + return count ``` -[Explanation] -This implementation handles both positive and negative file sizes and returns the formatted size as a string, even though negative sizes are not common in real-world scenarios. - [Tests] -You can write the following test cases to ensure it works correctly: +The function can be tested with various inputs to ensure it handles both strict and non-strict cases correctly: ```python -test_data = [ - (1024 ** 2, "1.0MB"), - (1024 ** 2 - 1, "1024.0KB"), - (1024 ** 3 * 512, "512.0GB"), - (1024 ** 4 * 512, "512.0TB"), -] -for size, expected_output in test_data: - assert format_size(size) == expected_output - assert format_size(-size) == f"-{expected_output}" +# Test function with matching characters +assert count_conditional_matches("hello", "o", True) == 1 +assert count_conditional_matches("fancy", "nancy", False) == 2 +assert count_conditional_matches("abcde", "e", False) == 1 +assert count_conditional_matches("helicopter", "ranges", True) == 1 + +# Test function with no matches +assert count_conditional_matches("apple", "banana", True) == 0 +assert count_conditional_matches("abcdef", "xyzab", False) == 0 + +# Test when both strings are empty +assert count_conditional_matches("", "", True) == 0 +assert count_conditional_matches("", "", False) == 0 + +# Test when one string is empty +assert count_conditional_matches("nonempty", "", True) == 0 +assert count_conditional_matches("", "nonempty", False) == 0 ``` -These test cases cover both positive and negative sizes across different units to validate the correctness of the `format_size` function. - -### Example 13 +### Example 15 [Code] -def unify_stringlist(L: list): - """ Adds asterisks to strings that appear multiple times, so the resulting - list has only unique strings but still the same length, order, and meaning. 
- For example: - unify_stringlist(['a','a','b','a','c']) -> ['a','a*','b','a**','c'] +def oxford_comma_text_to_list(phrase): + """Examples: + - 'Eeeny, Meeny, Miney, and Moe' --> ['Eeeny', 'Meeny', 'Miney', 'Moe'] + - 'Black and White' --> ['Black', 'White'] + - 'San Francisco and Saint Francis' --> + ['San Francisco', 'Saint Francisco'] """ - assert(all([isinstance(l,str) for l in L])) - return [L[i]+"*"*L[:i].count(L[i]) for i in range(len(L))] -assert unify_stringlist(list("abc")) == list("abc") + items = [] + for subphrase in phrase.split(', '): + items.extend( + [item.strip() for item in subphrase.split(' and ')]) + return items [Property] -category: code generation (function implementation) +category: function implementation language: Python -concepts: list comprehension, unique string creation, precondition checking for type validation difficulty: easy +concepts: string manipulation and parsing, list comprehension, iterative list extension, handling text connectors [Instruction] -Design a Python function to encode a list of strings into a unique list with the same length, order, and meaning. The new list is formed by appending asterisks (*) to duplicate strings. For instance, given ['a', 'a', 'b', 'a', 'c'], the function should return ['a', 'a*', 'b', 'a**', 'c'], while for ['a', 'b', 'c'] it should return ['a', 'b', 'c'] without any changes. Make a precondition check through `assert` to ensure the input is a list of strings. - -Solve the problem in two lines of code, one for the precondition check and the other for the main logic. - -[Solution] -```python -def encode_strings(input_list): - assert all(isinstance(item, str) for item in input_list), "Input must be a list of strings" - return [s + '*' * input_list[:i].count(s) for i, s in enumerate(input_list)] - -# Tests -assert encode_strings(['a', 'a', 'b', 'a', 'c']) == ['a', 'a*', 'b', 'a**', 'c'] -assert encode_strings(['a', 'b', 'c']) == ['a', 'b', 'c'] -``` +Create a concise Python function that transforms a string into a list of individual items. The input string contains a list of items separated by commas and the last item preceded by an optional "and". For example: +- Input: "Eeeny, Meeny, Miney, Moe" --> Output: ['Eeeny', 'Meeny', 'Miney', 'Moe'] +- Input: "Eeeny, Meeny, Miney, and Moe" --> Output: ['Eeeny', 'Meeny', 'Miney', 'Moe'] +- Input: "San Francisco and San Mateo" --> Output: ['San Francisco', 'San Mateo'] +Remove any leading or trailing whitespace. [Response] -[Reasoning] -You can achieve this functionality by first ensuring the input meets your requirements using an `assert` statement with list comprehension. Then, use another list comprehension to iterate through the input list and append `'*'` to the strings. The number of `'*'` is the number of occurrences of the string before the current index. 
- -[Implementation] -Here's how you can do it in two lines of Python code: +Here's a concise function `string_to_list` to transform a comma-separated string into a list, replacing " and " with a comma, splitting the result, and stripping whitespace: ```python -def encode_strings(input_list): - assert all(isinstance(item, str) for item in input_list), "Input must be a list of strings" - return [s + '*' * input_list[:i].count(s) for i, s in enumerate(input_list)] +def string_to_list(s): + return [item.strip() for item in s.replace(" and ", ",").split(",") if item.strip() != ""] ``` -[Explanation] -The `enumerate` function provides both the index and the string in each iteration, allowing you to count the occurrences of the string before the current index. +This approach ensures no empty strings are included if there are excess commas or spaces. [Tests] -Now you can test your function with the provided examples: +We can test this function on different input strings (e.g., with and without the "and" connector, different numbers of items): ```python -assert encode_strings(['a', 'a', 'b', 'a', 'c']) == ['a', 'a*', 'b', 'a**', 'c'] -assert encode_strings(['a', 'b', 'c']) == ['a', 'b', 'c'] +assert string_to_list("Eeeny, Meeny, Miney, Moe") == ['Eeeny', 'Meeny', 'Miney', 'Moe'] +assert string_to_list("Eeeny, Meeny, Miney, and Moe") == ['Eeeny', 'Meeny', 'Miney', 'Moe'] +assert string_to_list("San Francisco and San Mateo") == ['San Francisco', 'San Mateo'] +assert string_to_list("Single Item") == ['Single Item'] ``` -Note that although the function is concise, it is not the most efficient solution for large lists, as the `count` method has a time complexity of O(n) for each iteration. For larger lists, a more efficient approach might be to use a dictionary to store the counts of each string and then iterate through the list to append the appropriate number of `'*'` to each string. - -### Example 14 +### Example 16 [Code] -def edit_string_for_tags(tags): - """ - Simplified version of comma separated tags - """ - names = [] - for tag in tags: - name = tag.name - if ',' in name: - names.append('"%s"' % name) - continue - names.append(name) - return ','.join(names) +def get_hashes(buf): + hexdigests = namedtuple('Digests', 'md5 sha1 sha256 sha512 ctph') + if isinstance(buf, str): + buf = open(buf, 'rb').read() + md5 = hashlib.md5(buf).hexdigest() + sha1 = hashlib.sha1(buf).hexdigest() + sha256 = hashlib.sha256(buf).hexdigest() + sha512 = hashlib.sha512(buf).hexdigest() + ctph = ssdeep.hash(buf) + return hexdigests._make((md5, sha1, sha256, sha512, ctph)) [Property] -category: code generation (function implementation) +category: function implementation language: Python -concepts: object attribute access, conditional string formatting, joining strings with a separator difficulty: easy +concepts: binary string hash computation, file content reading, string hashing, named tuples, conditional file content reading [Instruction] -Write a Python function to format a list of tag objects (each with a `.name` attribute) and into a single string containing all tag names separated by commas (no spaces), where tag names containing commas are first enclosed in double quotes. +Implement a Python function that computes different types of hash values for the binary content of a file whose path is defined by an input string. You should calculate 'md5', 'sha1', 'sha256', and 'sha512' and return a `namedtuple` of these fields. 
When the input string does not correspond to a file, the function should directly compute the hash values of path itself. + +[Response] +To compute the hashes of a file or a string, we'll follow these steps: + +1. If the input is a valid file path, read the binary content of the file. +2. Otherwise, convert the input string to a binary string. +3. Compute hash values of all types using the `hashlib` module. +4. Return the hash values as a `namedtuple` for easy access. + +Here's how you can implement this function: -[Solution] ```python -def format_tags(tag_objects): - formatted_tags = [] +import hashlib +from collections import namedtuple +from pathlib import Path - for tag in tag_objects: - tag_name = tag.name - if ',' in tag_name: - formatted_tags.append(f'"{tag_name}"') - else: - formatted_tags.append(tag_name) +HashValues = namedtuple('HashValues', ['md5', 'sha1', 'sha256', 'sha512']) - return ','.join(formatted_tags) +def compute_hashes(file_path): + try: + with open(file_path, 'rb') as f: + content = f.read() + # Handle cases where the file does not exist or the path is a directory + except (FileNotFoundError, IsADirectoryError): + content = file_path.encode() + + md5_hash = hashlib.md5(content).hexdigest() + sha1_hash = hashlib.sha1(content).hexdigest() + sha256_hash = hashlib.sha256(content).hexdigest() + sha512_hash = hashlib.sha512(content).hexdigest() + + return HashValues(md5_hash, sha1_hash, sha256_hash, sha512_hash) +``` -# Tests -class Tag: - def __init__(self, name): - self.name = name +[Tests] +To test the function, you can provide string inputs to avoid file I/O operations and verify that the hash values are correctly computed: -tags = [Tag('python'), Tag('data science'), Tag('machine learning'), Tag('AI, Robotics')] -formatted_string = format_tags(tags) -assert formatted_string == 'python,data science,machine learning,"AI, Robotics"' +```python +string = 'random string' +binary = string.encode() +computed_hashes = compute_hashes(string) +assert computed_hashes.md5 == hashlib.md5(binary).hexdigest() +assert computed_hashes.sha1 == hashlib.sha1(binary).hexdigest() +assert computed_hashes.sha256 == hashlib.sha256(binary).hexdigest() +assert computed_hashes.sha512 == hashlib.sha512(binary).hexdigest() ``` +### Example 17 +[Code] +from collections import Counter + +class Solution: + def sumOfUnique(self, nums: List[int]) -> int: + return sum(k for k, v in Counter(nums).items() if v == 1) + +[Property] +category: function implementation +language: Python +difficulty: easy +concepts: dictionary-based counting, filtering unique elements with list comprehension, summation of elements based on condition + +[Instruction] +Write a Python function that takes a list from user input and prints the sum of counts for elements that occur more than once. Use a dictionary to count frequencies efficiently. Utilize a list comprehension to extract repeated elements based on the frequency counts, and sum their corresponding counts from the dictionary. + [Response] -[Reasoning] -You can create a Python function that iterates through the tags and checks if the `.name` attribute contains a comma to format it accordingly. +To implement this function, we first need to parse the user input as a list of integers and then count the frequency of each element using a dictionary. We can then calculate the sum of the counts for elements that occur more than once. 
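+As a minimal sketch of just the counting step (using `collections.Counter` purely for illustration; the implementation below keeps a plain dictionary):
+
+```python
+from collections import Counter
+
+# Count occurrences, then sum the counts of the values that appear more than once.
+nums = [1, 2, 2, 3, 3, 3]
+freq = Counter(nums)                               # Counter({3: 3, 2: 2, 1: 1})
+repeated_total = sum(v for v in freq.values() if v > 1)
+assert repeated_total == 5                         # two 2s + three 3s
+```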
-[Implementation] -Here's how you could do it: +Here's a Python function that accomplishes this: ```python -def format_tags(tag_objects): - formatted_tags = [] - - for tag in tag_objects: - tag_name = tag.name - if ',' in tag_name: - # If the tag name contains a comma, enclose it in double quotes - formatted_tags.append(f'"{tag_name}"') - else: - formatted_tags.append(tag_name) - - # Join all formatted tag names with commas (no spaces) - return ','.join(formatted_tags) +def sum_repeated_counts(): + nums = list(map(int, input().split())) + freq_dict = {} + for num in nums: + freq_dict[num] = freq_dict.get(num, 0) + 1 + sum_of_repeats = sum(v for k, v in freq_dict.items() if v > 1) + print(sum_of_repeats) ``` -[Explanation] -This function goes through each tag object, checks for commas in the tag's name, and formats it by enclosing it in double quotes if needed. Tags are then joined into a single string, separated by commas without spaces. - [Tests] -To test this function, you can create a simple `Tag` class with a `name` attribute and a list of `Tag` objects: +We can test this program with I/O redirection to simulate user input and output. Here is an example: ```python -class Tag: - def __init__(self, name): - self.name = name +import sys +from io import StringIO + +test_cases = [ + ("1 2 2 3 3 3\n", "5\n"), + ("1 1 1 1 1\n", "5\n"), + ("1 2 3 4 5\n", "0\n"), + ("\n", "0\n") +] -tags = [Tag('python'), Tag('data science'), Tag('machine learning'), Tag('AI, Robotics')] -formatted_string = format_tags(tags) -assert formatted_string == 'python,data science,machine learning,"AI, Robotics"' +for input_data, expected_output in test_cases: + stdin, stdout = sys.stdin, sys.stdout + sys.stdin = StringIO(input_data) + sys.stdout = StringIO() + sum_repeated_counts() + output = sys.stdout.getvalue() + assert output == expected_output ``` -This test ensures that the function correctly formats each tag name and combines them into the expected output string. +### Example 18 -### Example 15 [Code] -def _create_folds_list(data, count): - """ - Creates folds from the given data. +fp=open('cmd.sh','r') +lines=fp.readlines() +for i,line in enumerate(lines): + cmd=line.replace('\n','') + print i,cmd + p=os.popen(cmd) + x=p.read() + p.close() +fp.close - :param data: the data to fold - :param count: the number of folds to create +[Property] +category: program implementation +language: Python +difficulty: medium +concepts: iterating over file lines, command execution from file input, new line character removal - :return: a list of folds - """ +[Instruction] +Develop a Python program that reads a text file containing one shell command per line and executes these commands sequentially. Each command should be sanitized by removing unnecessary newline characters before execution. Provide clear error handling and logging to manage exceptions that might occur during the execution of the commands. - fold_count = len(data) / count - folds = list() +[Response] +To create a Python program that reads and executes shell commands from a text file, we can use Python’s built-in module `subprocess` for running the commands and `logging` for handling logs and errors. The program will involve reading and sanitizing the commands from a file, executing them sequentially, and logging the results or errors. 
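+As a minimal illustration of the core call before the full script (assuming a harmless command such as `echo`), `subprocess.run` with `shell=True` executes one sanitized line and captures its output:
+
+```python
+import subprocess
+
+# Strip the trailing newline, run the command, and capture its output as text.
+cmd = "echo hello\n".strip()
+result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
+print(result.stdout.strip())  # -> hello
+```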
- for fold_index in range(count): - low = int(fold_index * fold_count) - high = int((fold_index + 1) * fold_count) +Below is the Python script that demonstrates this: - fold = data[low:high] - folds.append(fold) +```python +import subprocess +import logging - return folds -assert _create_folds_list(range(12), 2) == [range(0, 6), range(6, 12)] -assert _create_folds_list(range(8), 2) == [range(0, 4), range(4, 8)] -assert _create_folds_list(range(25), -5) == [] -assert _create_folds_list(range(6), 2) == [range(0, 3), range(3, 6)] +logging.basicConfig(level=logging.INFO, format='%(asctime)s:%(levelname)s:%(message)s') -[Property] -category: code generation (function implementation) -language: Python -concepts: data partitioning and slicing, learning from input-output examples -difficulty: medium +def read_commands(file_path): + """Reads a text file and returns a list of commands.""" + try: + with open(file_path, 'r') as file: + commands = file.readlines() + commands = [cmd.strip() for cmd in commands] # Remove newline characters and spaces + return commands + except FileNotFoundError: + logging.error(f"The file {file_path} was not found.") + raise + except Exception as e: + logging.error(f"An error occurred while reading the file: {str(e)}") + raise + +def execute_command(command): + """Executes a single shell command.""" + try: + logging.info(f"Executing command: {command}") + result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + logging.info(f"Command output: {result.stdout.decode().strip()}") + except subprocess.CalledProcessError as e: + logging.error(f"Error executing command '{command}': {e.stderr.decode()}") + except Exception as e: + logging.error(f"An unexpected error occurred: {str(e)}") + +def main(file_path): + """Main function to read commands from a file and execute them.""" + commands = read_commands(file_path) + for command in commands: + execute_command(command) +``` -[Instruction] -Construct a Python function `create_folds(data, count)` to create a list of folds from the given data, where each fold is a subsequence of the original data. The length of each fold should be approximately equal with at most a difference of 1. The function should satisfy the following assertion: +[Tests] +To test this program, we can use Python's `unittest` framework along with `mock_open` to simulate file operations, `patch` to mock `subprocess.run` for testing command execution, and `assertLogs` to verify the logging output. 
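+As a standalone sketch of the `assertLogs` part (independent of the command runner; the test names here are illustrative only):
+
+```python
+import logging
+import unittest
+
+class TestLogCapture(unittest.TestCase):
+    def test_error_is_captured(self):
+        # assertLogs collects records at ERROR level and above for inspection.
+        with self.assertLogs(level="ERROR") as captured:
+            logging.error("Error executing command 'ls'")
+        self.assertIn("Error executing command 'ls'", captured.output[0])
+
+if __name__ == "__main__":
+    unittest.main()
+```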
Here's an example of how you can write tests for this program: ```python -assert create_folds(range(12), 2) == [range(0, 6), range(6, 12)] -assert create_folds(range(8), 2) == [range(0, 4), range(4, 8)] -assert create_folds(range(25), -5) == [] -assert create_folds(range(6), 2) == [range(0, 3), range(3, 6)] +import unittest +from unittest.mock import mock_open, patch, MagicMock + +class TestCommandExecutor(unittest.TestCase): + def test_read_commands_success(self): + m = mock_open(read_data="ls\npwd\n") + with patch('builtins.open', m): + result = read_commands("dummy_path") + self.assertEqual(result, ['ls', 'pwd']) + + def test_read_commands_file_not_found(self): + with patch('builtins.open', mock_open()) as mocked_file: + mocked_file.side_effect = FileNotFoundError + with self.assertRaises(FileNotFoundError): + read_commands("nonexistent_path") + + @patch('subprocess.run') + def test_execute_command_success(self, mock_run): + mock_run.return_value = MagicMock(stdout=b"Execution successful", stderr=b"") + execute_command("ls") + mock_run.assert_called_once_with("ls", shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + @patch('subprocess.run') + def test_execute_command_failure(self, mock_run): + error_message = b"Command failed" + mock_run.side_effect = subprocess.CalledProcessError(1, 'ls', stderr=error_message) + with self.assertLogs(level='ERROR') as log: + execute_command("ls") + self.assertIn("Error executing command 'ls'", log.output[0]) + self.assertIn("Command failed", log.output[0]) + +if __name__ == "__main__": + unittest.main() ``` -[Solution] -```python -def create_folds(data, count): - if count <= 0: - return [] +### Example 19 +[Code] +def busca_linear_recursiva(array,elemento): + i = -1 + return busca_recursiva(array,elemento, i) + +def busca_recursiva(array,elemento, i): + i += 1 + if(i == len(array)): + return -1 + if(array[i] == elemento): + return i + return busca_recursiva(array,elemento,i) - data_length = len(data) - fold_size = data_length // count - remainder = data_length % count +[Property] +category: program implementation +language: Python +difficulty: medium +concepts: linear search with recursion, default parameter initialization, multiple base cases in recursion - folds = [] - start_index = 0 +[Instruction] +Design a Python program that not only searches for a target value in a list using recursion but also constructs a trace of the indices it checks. This program should return both the index of the target, if found, or `-1` if not, alongside a list of the indices it examined during the search. The function should handle cases where the target is not present in the list. - for i in range(count): - end_index = start_index + fold_size + (1 if i < remainder else 0) - folds.append(range(start_index, end_index)) - start_index = end_index +[Response] +We can implement a recursive linear search function that returns the index of the target element if found or `-1` if not. To keep track of the indices examined during the search, we can pass a list as an additional parameter to store these indices. 
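+A compact sketch of the same accumulator idea with a default parameter (hypothetical names, shown only to make the pattern concrete before the full program):
+
+```python
+def find_with_trace(items, target, i=0, seen=None):
+    # 'seen' accumulates every index examined across the recursive calls.
+    seen = [] if seen is None else seen
+    if i >= len(items):
+        return -1, seen
+    seen.append(i)
+    if items[i] == target:
+        return i, seen
+    return find_with_trace(items, target, i + 1, seen)
+
+assert find_with_trace([5, 3, 7], 7) == (2, [0, 1, 2])
+assert find_with_trace([5, 3, 7], 9) == (-1, [0, 1, 2])
+```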
- return folds +Here is the Python program that implements this functionality: -# Tests -assert create_folds(range(12), 2) == [range(0, 6), range(6, 12)] -assert create_folds(range(8), 2) == [range(0, 4), range(4, 8)] -assert create_folds(range(25), -5) == [] -assert create_folds(range(6), 2) == [range(0, 3), range(3, 6)] +```python +def search_recursive(arr, target, index, indices_checked): + """Recursive function to search for the target value in the list.""" + if index >= len(arr): + return -1 + indices_checked.append(index) + if arr[index] == target: + return index + return search_recursive(arr, target, index + 1, indices_checked) + +def recursive_search(arr, target): + """Wrapper function to initiate recursive search.""" + indices_checked = [] + result = search_recursive(arr, target, 0, indices_checked) + return result, indices_checked + +# Example usage +if __name__ == "__main__": + array = [2, 3, 4, 8, 5, 9, 1, 6, 5, 7, 5] + target = 6 + result, indices = recursive_search(array, target) + print(f"Target {target} found at index {result} with indices checked: {indices}") ``` -[Response] -[Reasoning] -To create a Python function that generates a list of folds from given data with each fold having approximately equal length (with at most a difference of 1 between any two folds), you can follow these steps: - -1. Return an empty list immediately if `count` is non-positive since it's not possible to create a valid number of folds. -2. Divide the total length of the data by `count`, rounding down, to get the minimum number of items each fold should have. -3. Calculate the remainder to see how many extra items need to be distributed among the folds. They are distributed one by one to the first `remainder` folds. -4. Use a loop to create each fold, adjusting the starting and ending indices based on the calculations above to ensure each fold has the correct number of items. +In this program, the `search_recursive` function recursively searches for the target element in the list and appends the indices checked to the `indices_checked` list. The `recursive_search` function acts as a wrapper to initiate the recursive search and return the result along with the list of indices checked. -[Implementation] -Here's how you can implement this: +[Tests] +To test the implementation, we need to ensure that both the index of the target element and the list of indices checked are correct. 
Here are the test cases: ```python -def create_folds(data, count): - # Return an empty list if count is non-positive - if count <= 0: - return [] +# Test case 1: Target found in the list +arr1 = [2, 3, 4, 8, 5, 9, 1, 6, 5, 7, 5] +target1 = 6 +expected_result1 = arr1.index(target1) +expected_indices1 = list(range(expected_result1 + 1)) +result1, indices_checked1 = recursive_search(arr1, target1) +assert result1 == expected_result1 +assert indices_checked1 == expected_indices1 + +# Test case 2: Target not found in the list +arr2 = [1, 2, 3, 4, 5] +target2 = 6 +expected_result2 = -1 +expected_indices2 = list(range(len(arr2))) +result2, indices_checked2 = recursive_search(arr2, target2) +assert result2 == expected_result2 +assert indices_checked2 == expected_indices2 + +# Test case 3: Target found at the end of the list +arr3 = [5, 4, 3, 2, 1] +target3 = 1 +expected_result3 = len(arr3) - 1 +expected_indices3 = list(range(expected_result3 + 1)) +result3, indices_checked3 = recursive_search(arr3, target3) +assert result3 == expected_result3 +assert indices_checked3 == expected_indices3 +``` - data_length = len(data) - fold_size = data_length // count - remainder = data_length % count +### Example 20 +[Code] +Data = np.zeros([break_points[-1],size_blocks]) +Data_stacked = np.zeros([break_points[-1]-num_blocks+1, size_blocks*num_blocks]) +cluster_point_list = [] +for counter in xrange(len(break_points)): + break_pt = break_points[counter] + cluster = seg_ids[counter] + if counter == 0: + old_break_pt = 0 + else: + old_break_pt = break_points[counter-1] + for num in xrange(old_break_pt,break_pt): + ##generate the point from this cluster + # print "num is:", num + if num == 0: - folds = [] - start_index = 0 +[Property] +category: program implementation +language: Python +difficulty: medium +concepts: NumPy array initialization, iteration over break points, cluster segmentation, nested loops - for i in range(count): - # Add 1 element to each of the first `remainder` folds - end_index = start_index + fold_size + (1 if i < remainder else 0) - folds.append(range(start_index, end_index)) - start_index = end_index +[Instruction] +Implement a Python program that initializes a NumPy array randomly (0--1) and segments it into clusters based on provided breakpoints. Each cluster is defined by consecutive elements between two breakpoints. The output should be a list of NumPy arrays representing these clusters. - return folds +For example, given a size 10 and breakpoints `[3, 6, 8]`, the program should first initialize a NumPy array of 10 random elements between 0 and 1 and then output three clusters: elements at indices `0-2`, `3-5`, `6-7`, and `8-9`. Ensure that the program handles cases where breakpoints are at the boundaries of the array or if there are no breakpoints, in which case the entire array should be a single cluster. Include error handling to manage invalid inputs such as negative array sizes or breakpoints that are out of range. 
+ +[Response] +Here is the program that achieves what you've described: + +```python +import numpy as np + +def segment_array(data, break_points): + """Segment the given array into clusters based on specified breakpoints.""" + break_points = [0] + break_points + [len(data)] + segmented_arrays = [] + for i in range(len(break_points) - 1): + start_index = break_points[i] + end_index = break_points[i + 1] + segmented_arrays.append(data[start_index:end_index]) + return segmented_arrays + +def create_clusters(size, break_points): + """Generate and segment a random array of given size using breakpoints.""" + if size <= 0: + raise ValueError("Array size must be positive.") + data = np.random.rand(size) + clusters = segment_array(data, break_points) + return clusters + +# Example usage +if __name__ == "__main__": + size = 10 + break_points = [3, 6, 8] + clusters = create_clusters(size, break_points) + for i, cluster in enumerate(clusters): + print(f"Cluster {i}: {cluster}") ``` -[Explanation] -This implementation ensures that each fold has an equal length with at most a difference of 1 between any two folds. It handles edge cases, such as a non-positive `count`, by returning an empty list as specified. +In this implementation, `segment_array` transforms the input data array into a list of NumPy arrays based on the provided breakpoints. The `main` function initializes a random NumPy array of a specified size and calls `segment_array` to segment the array into clusters based on the breakpoints. [Tests] -You can test your code with the provided assertions to verify that the function works correctly: +To test the program, we can verify that the values in each cluster are within the expected range (0--1) and that the clusters are correctly segmented. Here are some test cases: ```python -assert create_folds(range(12), 2) == [range(0, 6), range(6, 12)] -assert create_folds(range(8), 2) == [range(0, 4), range(4, 8)] -assert create_folds(range(25), -5) == [] -assert create_folds(range(6), 2) == [range(0, 3), range(3, 6)] +def test_segmentation(size, break_points): + clusters = create_clusters(size, break_points) + for cluster in clusters: + assert np.all(cluster >= 0) and np.all(cluster <= 1) + break_points = [0] + break_points + [size] + for i in range(len(break_points) - 1): + assert len(clusters[i]) == break_points[i + 1] - break_points[i] + +# Test cases +test_cases = [ + (20, [5, 10, 15]), + (10, [3, 6, 8]), + (15, [3, 6, 12]), + (30, [7, 14, 21, 27]), + (5, [1]), + (10, []), + (50, [10, 20, 30, 40]), +] + +for size, breakpoints in test_cases: + test_segmentation(size, breakpoints) ``` -### Example 16 +### Example 21 [Code] -def extract_text(obj): - """Extracts data, concatenates and removes html tags - from fields listed in a obj.related_source_fields list. 
- """ - related_source_fields = getattr(obj._meta.model, "related_source_fields", None) - if not related_source_fields: - return - - html_pieces = [] - for source_field in related_source_fields: - field = source_field.get_field(obj.__class__) - field_value = source_field.get_value(obj) - if isinstance(field, StreamField): - field_value = " ".join(field_value) - - html_pieces.append(field_value) +if match_case == 'auto': + if len(set([ch_name.lower() for ch_name in ch_names])) < len(set(ch_names)): + match_case = True + else: + match_case = False - text = " ".join(html_pieces) +if match_case: + picks = mne.pick_channels(ch_names, pick_chs, ordered=ordered) +else: + ch_names = [ch_name.lower() for ch_name in ch_names] + pick_chs = [pick_ch.lower() for pick_ch in pick_chs] + picks = mne.pick_channels(ch_names, pick_chs, ordered=ordered) - text = BeautifulSoup(text, "html5lib").text - return text +return picks [Property] -category: code generation (program implementation) +category: program implementation language: Python -concepts: attribute existence check, string concatenation, HTML tag stripping, graceful error handling difficulty: medium +concepts: case-insensitive string matching, controlled case matching based on an option, list comprehension [Instruction] -I want to pull text data from certain fields of an object---these fields are named in a `related_source_fields` attribute belonging to the object's class. After concatinating the text, the function should strip out any HTML tags to leave just the clean, plain text. Also, the aforementioned attribute may or may not exist. Make sure your program handles this case gracefully. +Construct a Python program that processes a list of strings provided by the user. The first element is a target string, the second is either 'CaseSensitive' or 'CaseInsensitive' to determine the matching approach, and the remaining elements are candidates for comparison. The program should output the number of times the target appears among the candidates according to the specified matching rule. The solution should include robust error handling to validate the matching rule and check that at least two strings are provided before proceeding. -[Solution] +For example, given the input `['apple', 'CaseInsensitive', 'Apple', 'banana', 'APPLE']`, the program should output `2` since 'Apple' and 'APPLE' match the target 'apple' in a case-insensitive manner. [Response] -You can achieve this in Python by checking if the `related_source_fields` attribute exists in the object's class, concatenates the text from the specified fields, and then strips out any HTML tags to leave just the plain text. You can use the BeautifulSoup library from `bs4` to remove HTML tags. +Here's a Python program that fulfills the requirements you specified. It first parses the input list to validate the presence of the target string and matching rule. Then, it compares the target with the candidate strings, adjusting for case sensitivity as specified. Finally, it outputs the count of matches. -First, ensure you have BeautifulSoup installed in your environment. 
If not, you can install it using pip: +```python +def count_matches(input_list): + if not input_list or len(input_list) < 2: + raise ValueError("Input list must contain at least two items: the target string and the matching rule.") -```sh -pip install beautifulsoup4 -``` + target_string, matching_rule = input_list[0], input_list[1] + if matching_rule not in ['CaseSensitive', 'CaseInsensitive']: + raise ValueError("Matching rule must be either 'CaseSensitive' or 'CaseInsensitive'.") -Here's how you could implement the logic in a Python function: + candidates = input_list[2:] + if matching_rule == 'CaseSensitive': + match_count = sum(1 for candidate in candidates if candidate == target_string) + else: + target_string = target_string.lower() + match_count = sum(1 for candidate in candidates if candidate.lower() == target_string) -```python -from typing import Any -from bs4 import BeautifulSoup - -def extract_and_clean_text(obj: Any) -> str: - """Extract text from specified fields and clean it of HTML.""" - # Check for 'related_source_fields' in the object's class, defaulting to an empty list if not found - fields = getattr(obj.__class__, 'related_source_fields', []) - - # Gather text from each specified field, if the field exists in the object - text_pieces = [] - for field in fields: - if hasattr(obj, field): - field_value = getattr(obj, field, '') - text_pieces.append(field_value) - - # Combine all pieces of text into a single string and clean it of HTML tags - combined_text = " ".join(text_pieces) - cleaned_text = BeautifulSoup(combined_text, "html.parser").get_text() - return cleaned_text + return match_count + +def main(): + input_list = input().split() + print(count_matches(input_list)) + +if __name__ == "__main__": + main() ``` -In this implementation, we use `getattr` with a default value of an empty list to handle the case where `related_source_fields` is not defined. We then gather the text from each specified field, combine all pieces of text into a single string, and clean it of HTML tags using BeautifulSoup. +The `count_matches` function implements the main logic of the program, counting the occurrences of the target string among the candidate strings based on the specified matching rule. The `main` function reads the input list from the user, calls `count_matches`, and prints the count of matches. -Here's an example of how you could use this function: +[Tests] +You can test the program with various inputs to ensure it handles different scenarios correctly: ```python -class ExampleClass: - related_source_fields = ['title', 'description'] - title = '

<h1>This is a title</h1>' - description = '<p>This is a description.</p>
' - -obj = ExampleClass() -print(extract_and_clean_text(obj)) +assert count_matches(['apple', 'CaseInsensitive', 'Apple', 'banana', 'APPLE']) == 2 +assert count_matches(['apple', 'CaseSensitive', 'Apple', 'banana', 'APPLE']) == 0 +assert count_matches(['apple', 'CaseInsensitive']) == 0 +assert count_matches(['apple', 'CaseInsensitive', 'apple', 'APPLE', 'aPpLe', 'banana']) == 3 +assert count_matches(['apple', 'CaseSensitive', 'apple', 'APPLE', 'aPpLe', 'banana']) == 1 + +try: + count_matches(['apple']) + assert False +except ValueError as e: + assert str(e) == "Input list must contain at least two items: the target string and the matching rule." + +try: + count_matches(['apple', 'InvalidRule', 'Apple']) + assert False +except ValueError as e: + assert str(e) == "Matching rule must be either 'CaseSensitive' or 'CaseInsensitive'." ``` \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9fc8c7b..6b32cd5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ tiktoken~=0.6.0 accelerate>=0.27.2 datasets>=2.17.1 evalplus @ git+https://github.com/evalplus/evalplus.git@25e195e024b614f2671ad9ac5b8fdcd9b95a2b24#egg=evalplus +evoeval~=0.1.0 diff --git a/sanitize.sh b/sanitize.sh new file mode 100755 index 0000000..53f44af --- /dev/null +++ b/sanitize.sh @@ -0,0 +1,41 @@ +#!/bin/bash +set -e + +SOURCE=$1 +TARGET=$2 + +echo "Sanitizing.." +python -m star_align.sanitize_data \ + --data_files $SOURCE \ + --output_file $TARGET \ + --parse_raw_response True \ + --exact_match_dedup True \ + --passing_only True \ + --include_left_failed False + +if [[ -n $DECONTAMINATION ]]; then + echo "Decontaminating.. (saving to decontamination-output)" + python -m star_align.decontamination.find_substrings \ + --dataset_name "json" \ + --output_file $TARGET \ + --output_dir decontamination-output \ + --columns instruction response \ + --data_files $TARGET +fi + +echo "Minihash dedup.." 
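+# The three MinHash passes below near-deduplicate by instruction, by response,
+# and by code representation in turn; each pass rewrites $TARGET in place.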
+python -m star_align.minhash_dedup \ + --data_files $TARGET \ + --column instruction \ + --output $TARGET + +python -m star_align.minhash_dedup \ + --data_files $TARGET \ + --column response \ + --output $TARGET + +python -m star_align.minhash_dedup \ + --data_files $TARGET \ + --column code_representation \ + --ignore_empty True \ + --output $TARGET diff --git a/self_ossinstruct_sc2.sh b/self_ossinstruct_sc2.sh new file mode 100755 index 0000000..6a768da --- /dev/null +++ b/self_ossinstruct_sc2.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +echo "MODE: $MODE" +echo "SEED_DATA_FILE: $SEED_DATA_FILE" +echo "INDEX: $INDEX" +echo "MAX_NEW_DATA: $MAX_NEW_DATA" +echo "DIR: $1" + +# if mode is "I->R", num of samples is 10, otherwise 1 +if [[ "$MODE" == "I->R" ]]; then + N_SAMPLES=1 + NUM_FEWSHOTS=1 + NUM_BATCHED_REQUESTS=4096 + ASYNC_MICRO_BATCH_SIZE=16 +else + N_SAMPLES=1 + NUM_FEWSHOTS=8 + NUM_BATCHED_REQUESTS=4096 + ASYNC_MICRO_BATCH_SIZE=8 +fi + +echo "N_SAMPLES: $N_SAMPLES" +echo "NUM_FEWSHOTS: $NUM_FEWSHOTS" +echo "NUM_BATCHED_REQUESTS: $NUM_BATCHED_REQUESTS" +echo "ASYNC_MICRO_BATCH_SIZE: $ASYNC_MICRO_BATCH_SIZE" + +COMMAND="python -m star_align.self_ossinstruct \ + --async_micro_batch_size $ASYNC_MICRO_BATCH_SIZE \ + --use_vllm_server True \ + --instruct_mode '$MODE' \ + --seed_data_files $SEED_DATA_FILE \ + --max_new_data $MAX_NEW_DATA \ + --tag sc2-${NUM_FEWSHOTS}shot \ + --temperature 0.7 \ + --seed_code_start_index $INDEX \ + --model bigcode/starcoder2-15b \ + --num_fewshots $NUM_FEWSHOTS \ + --num_batched_requests $NUM_BATCHED_REQUESTS \ + --num_sample_per_request $N_SAMPLES \ + --save_dir $1" + +if [[ -n "$2" ]]; then + COMMAND="$COMMAND --continue_from $2" +fi + +echo "Running command: $COMMAND" +eval $COMMAND diff --git a/self_ossinstruct_sc2_parallel.sh b/self_ossinstruct_sc2_parallel.sh new file mode 100755 index 0000000..a043c2c --- /dev/null +++ b/self_ossinstruct_sc2_parallel.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +echo "MODE: $MODE" +echo "SEED_DATA_FILE: $SEED_DATA_FILE" +echo "INDEX: $INDEX" +echo "MAX_NEW_DATA: $MAX_NEW_DATA" +echo "DIR: $1" + +NUM_GPUS=$(nvidia-smi --query-gpu=count --format=csv,noheader,nounits | head -n 1) + +DATA_CHUNK_SIZE=$(($MAX_NEW_DATA / $NUM_GPUS)) +REMAINDER=$(($MAX_NEW_DATA % $NUM_GPUS)) + +if [[ "$MODE" == "I->R" ]]; then + N_SAMPLES=1 + NUM_FEWSHOTS=1 + NUM_BATCHED_REQUESTS=4096 + ASYNC_MICRO_BATCH_SIZE=16 +else + N_SAMPLES=1 + NUM_FEWSHOTS=8 + NUM_BATCHED_REQUESTS=4096 + ASYNC_MICRO_BATCH_SIZE=8 +fi + +echo "N_SAMPLES: $N_SAMPLES" +echo "NUM_FEWSHOTS: $NUM_FEWSHOTS" +echo "NUM_BATCHED_REQUESTS: $NUM_BATCHED_REQUESTS" +echo "ASYNC_MICRO_BATCH_SIZE: $ASYNC_MICRO_BATCH_SIZE" + +PIDS=() +function killall_pids { + for pid in ${PIDS[@]}; do + kill $pid + done +} +trap killall_pids SIGINT SIGTERM + +for (( GPU_ID=0; GPU_ID<$NUM_GPUS; GPU_ID++ )) +do + START_INDEX=$(($INDEX + $GPU_ID * $DATA_CHUNK_SIZE)) + if [[ $GPU_ID -lt $REMAINDER ]]; then + CHUNK_SIZE=$(($DATA_CHUNK_SIZE + 1)) + else + CHUNK_SIZE=$DATA_CHUNK_SIZE + fi + END_INDEX=$(($START_INDEX + $CHUNK_SIZE - 1)) + + echo "Starting process for GPU $GPU_ID with data from $START_INDEX to $END_INDEX..." 
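+    # Each worker writes to its own subdirectory and runs an in-memory vLLM engine
+    # pinned to a single GPU via CUDA_VISIBLE_DEVICES (hence --use_vllm_server False).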
+ + OUTDIR="$1/$GPU_ID" + mkdir -p $OUTDIR + + CUDA_VISIBLE_DEVICES=$GPU_ID python -m star_align.self_ossinstruct \ + --async_micro_batch_size $ASYNC_MICRO_BATCH_SIZE \ + --use_vllm_server False \ + --instruct_mode "$MODE" \ + --seed_data_files $SEED_DATA_FILE \ + --max_new_data $CHUNK_SIZE \ + --tag sc2-${NUM_FEWSHOTS}shot \ + --temperature 0.7 \ + --seed_code_start_index $START_INDEX \ + --model bigcode/starcoder2-15b \ + --num_fewshots $NUM_FEWSHOTS \ + --num_batched_requests $NUM_BATCHED_REQUESTS \ + --num_sample_per_request $N_SAMPLES \ + --save_dir $OUTDIR & + PIDS+=($!) +done + +wait + +if [[ $? -ne 0 ]]; then + echo "Error in one of the processes. Exiting... Check logs for more details." + exit 1 +fi + +# dir for final res +FINAL="$1/final" +FINAL_FILE="$FINAL/aggregated-${MODE}.jsonl" + +echo "All processes finished. Aggregating results... to $FINAL_FILE" + +# aggregate! +mkdir -p $FINAL +touch $FINAL_FILE + +for (( GPU_ID=0; GPU_ID<$NUM_GPUS; GPU_ID++ )) +do + # get first file for dir + FILE=$(ls $1/$GPU_ID | head -n 1) + cat $1/$GPU_ID/$FILE >> $FINAL_FILE +done + +echo "Done!" diff --git a/src/star_align/code_exec_server b/src/star_align/code_exec_server index 32af017..5bbc529 160000 --- a/src/star_align/code_exec_server +++ b/src/star_align/code_exec_server @@ -1 +1 @@ -Subproject commit 32af017251fd6cc98382f5295e398f3680e7fd79 +Subproject commit 5bbc52911705ef39870f99ae6b923ae65012ff14 diff --git a/src/star_align/collect_snippets.py b/src/star_align/collect_snippets.py new file mode 100644 index 0000000..0ce033b --- /dev/null +++ b/src/star_align/collect_snippets.py @@ -0,0 +1,603 @@ +import json +import random +import re +import textwrap +import warnings +from dataclasses import dataclass, field +from pathlib import Path +from typing import Generator, cast + +from datasets import Dataset, concatenate_datasets, load_dataset +from tqdm.auto import tqdm +from transformers import HfArgumentParser +import star_align + +print("Warning: ignoring warnings") +warnings.filterwarnings("ignore") + + +@dataclass(frozen=True) +class Args: + data_dirs: list[str] + data_mix_weights: list[float] + + max_seeds_to_collect: int = field(default=100000000) + continue_from: str | None = field(default=None) + + # Keep the following arguments unchanged for reproducibility + seed: int = field(default=976) + + min_lines: int = field(default=5) + max_lines: int = field(default=30) + min_doc_lines: int = field(default=10) + max_doc_lines: int = field(default=5000) + max_avg_chars_per_line: int = field(default=80) + # max_fragments: int = field(default=3) + chunk_size: int = field(default=1000) + # A small value lets one document be used by multiple seeds + content_chunk_lines: int = field(default=100) + + dataset_name: str = field(default="bigcode/starcoderdata") + data_files: list[str] | None = field(default=None) + max_considered_data: int | None = field(default=500000000) + + collect_function: bool = field(default=False) + max_nodes_to_traverse: int = field(default=20000) + + tag: str = field( + default="", + metadata={ + "help": "Custom tag as part of the output filename, not affecting the fingerprint" + }, + ) + text_wrap: int | None = field(default=None) + + std_lib_only: bool = field(default=False) + min_stars: int = field(default=0) + n_cores: int = field(default=star_align.utils.N_CORES) + + def fingerprint(self) -> str: + # The combination of arguments can uniquely determine the generation process + args = [ + self.data_dirs, + self.data_mix_weights, + self.seed, + self.min_lines, + 
self.max_lines, + self.min_doc_lines, + self.max_doc_lines, + self.max_avg_chars_per_line, + self.chunk_size, + self.dataset_name, + self.max_considered_data, + self.content_chunk_lines, + self.min_stars, + self.text_wrap, + self.data_files, + self.collect_function, + self.max_nodes_to_traverse, + ] + # for backward compatibility, only add if needed + # if self.text_wrap is not None: + # args.append(self.text_wrap) + # if self.data_files is not None: + # args.append(self.data_files) + return star_align.utils.compute_fingerprint(*args, hash_length=5) + + +# def fragments_to_text(fragments: list[str]) -> str: +# return "...\n".join(fragments) + + +ASTTypeMap = { + "cpp": [ + "function_definition" + ], # https://github.com/tree-sitter/tree-sitter-cpp/blob/master/grammar.js + # "csharp": [ + # "method_declaration" + # ], # https://github.com/tree-sitter/tree-sitter-c-sharp/blob/master/grammar.js + "java": [ + "method_declaration" + ], # https://github.com/tree-sitter/tree-sitter-java/blob/master/grammar.js + "php": [ + "method_declaration" + ], # https://github.com/tree-sitter/tree-sitter-php/blob/master/grammar.js + "python": [ + "function_definition" + ], # https://github.com/tree-sitter/tree-sitter-python/blob/master/grammar.js + "rust": [ + "function_item" + ], # https://github.com/tree-sitter/tree-sitter-rust/blob/master/grammar.js + # "swift": [ + # "function_declaration" + # ], # https://github.com/alex-pinkus/tree-sitter-swift/blob/main/grammar.js + "typescript": [ + "function_declaration", + "method_definition", + ], # https://github.com/tree-sitter/tree-sitter-typescript/blob/master/typescript/grammar.js + # "bash": [ + # "function_definition" + # ], # https://github.com/tree-sitter/tree-sitter-bash/blob/master/grammar.js +} + +EXCLUDED_PATTERNS = { + "python": { + "function_definition": [ + "def __[a-zA-Z0-9_]*__\\(", # Broadly exclude magic methods, but note this may also exclude some user-defined dunder methods which might be 'interesting' + "def test[a-zA-Z0-9_]*\\(", # Exclude test functions + # ends with pass + # "\\(self[,\\)]", + "pass$", + ] + }, + "cpp": { + "function_definition": [ + "~[a-zA-Z0-9_]*\\(", # Exclude C++ destructors + ] + }, + "java": { + "function_definition": [ + "[gs]et[A-Z][0-9a-zA-Z_]*\\(", + ] + }, + "php": { + "function_definition": [ + "function __[A-Za-z0-9_]*", # Exclude PHP magic methods + "function test[a-zA-Z0-9_]*\\(", # Exclude PHP magic methods + ] + }, + "rust": { + "function_definition": [ + "fn new(", + "fn default(", + "fn test[a-zA-Z0-9_]*\\(", + ] + }, + "typescript": { + "function_definition": [ + "constructor(", + ] + }, +} + + +def extract_snippets_with_constraints( + tree, + source_code: str, + considered_types: list[str], + min_lines: int, + max_lines: int, + max_nodes_to_traverse: int, + excluded_patterns: dict[str, list[str]], +) -> list[str]: + """ + Extract snippets from the source code that match the given constraints. 
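+    Only nodes whose type is in considered_types, whose line count falls within
+    [min_lines, max_lines], and whose text matches none of the excluded_patterns
+    registered for that node type are kept; at most max_nodes_to_traverse nodes
+    are visited, and each kept snippet is dedented before being returned.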
+ """ + from tree_sitter import Node, Tree + assert isinstance(tree, Tree) + + matching_snippets = list[str]() + # source_code = root_node.text.decode() + + # https://github.com/tree-sitter/py-tree-sitter/blob/master/examples/walk_tree.py + def traverse_tree(tree: Tree) -> Generator[Node, None, None]: + cursor = tree.walk() + + visited_children = False + while True: + if not visited_children: + yield cursor.node + if not cursor.goto_first_child(): + visited_children = True + elif cursor.goto_next_sibling(): + visited_children = False + elif not cursor.goto_parent(): + break + + all_nodes_iter = traverse_tree(tree) + for node, _ in zip(all_nodes_iter, range(max_nodes_to_traverse)): + if node.type in considered_types: + start_line = node.start_point[0] # line numbers are zero-indexed + end_line = node.end_point[0] + line_count = end_line - start_line + 1 + + if min_lines <= line_count <= max_lines: + snippet = extract_code_with_indentation(node, source_code) + # Check if the fragment contains any of the excluded keywords + if all( + re.search(pattern, snippet) is None + for pattern in excluded_patterns.get(node.type, []) + ): + snippet = textwrap.dedent(snippet) + matching_snippets.append(snippet) + + return matching_snippets + + +def extract_code_with_indentation(node, source_code: str) -> str: + """ + Extract the source code corresponding to a given node from the original source code string, + including the indentation based on the node's starting line. + """ + start_byte = node.start_byte + end_byte = node.end_byte + # Find the newline character before the node starts, to determine the start of the line + start_of_line = ( + source_code.rfind("\n", 0, start_byte) + 1 + ) # +1 to move past the newline character + # Calculate the indentation by counting whitespace characters from the start of the line to the node start + indentation = "" + for i in range(start_of_line, start_byte): + if source_code[i].isspace(): + indentation += source_code[i] + else: + break + # Extract the code and prepend the indentation to each line + code_fragment = source_code[start_byte:end_byte] + indented_code_fragment = indentation + code_fragment + return indented_code_fragment + + +def chunk_content(examples: dict, indices: list[int], args: Args) -> dict: + contents = examples["content"] + examples["raw_index"] = indices + + def chunk(content: str) -> list[str]: + lines = content.splitlines(keepends=True) + chunks = list[str]() + for end_idx in range(len(lines), 0, -args.content_chunk_lines): + # throw away the last chunk if it's too small + if end_idx < args.content_chunk_lines and len(chunks) > 0: + break + chunks.append("".join(lines[end_idx - args.content_chunk_lines : end_idx])) + return chunks + + new_data: dict = dict() + for index in range(len(contents)): + content = contents[index] + chunked_contents = chunk(content) + new_data.setdefault("chunked_content", []).extend(chunked_contents) + for key in examples.keys(): + new_others = [examples[key][index]] * len(chunked_contents) + new_data.setdefault(key, []).extend(new_others) + + return new_data + + +def map_dataset(examples: dict, indices: list[int], args: Args, data_dir: str) -> dict: + random.seed(args.seed + sum(map(ord, data_dir)) + indices[0]) + stars = list(map(int, examples["max_stars_count"])) + content_key = "chunked_content" if not args.collect_function else "content" + contents = list(map(sanitize_document, examples[content_key])) + assert len(contents) == len(stars) + if not args.collect_function: + seed_fragments = [ + ( + 
extract_fragment(args, content) + if content is not None and star >= args.min_stars + else None + ) + for star, content in zip(stars, contents) + ] + # seed = [ + # (fragments_to_text(fragments) if fragments is not None else None) + # for fragments in seed_fragments + # ] + assert len(seed_fragments) == len(indices) + return { + "seed": seed_fragments, + "repo": examples["max_stars_repo_name"], + "star": stars, + "id": list(map(int, indices)), + "raw_index": examples["raw_index"], + "data_dir": [data_dir for _ in indices], + # "content": examples["content"] + } + from tree_sitter_languages import get_parser + + parser = get_parser(data_dir) + repos = examples["max_stars_repo_name"] + ids = examples["id"] + raw_indices = indices + data_dirs = [data_dir for _ in indices] + data: dict[str, list] = { + "seed": [], + "repo": [], + "star": [], + "id": [], + "raw_index": [], + "data_dir": [], + } + for repo, star, id, raw_index, content, data_dir in zip( + repos, stars, ids, raw_indices, contents, data_dirs + ): + if ( + content is None + or star < args.min_stars + or (n_lines := content.count("\n")) < args.min_doc_lines + or n_lines > args.max_doc_lines + or len(content) > args.max_avg_chars_per_line * args.max_doc_lines + ): + continue + try: + content_encoded = content.encode() + if len(content_encoded) != len(content): + # probably Non-english + continue + tree = parser.parse(content_encoded) + root_node = tree.root_node + if root_node.has_error: + continue + # if len(root_node.text) != len(content): + # # probably Non-english + # continue + snippets = extract_snippets_with_constraints( + tree=tree, + source_code=content, + considered_types=ASTTypeMap[data_dir], + min_lines=args.min_lines, + max_lines=args.max_lines, + max_nodes_to_traverse=args.max_nodes_to_traverse, + excluded_patterns=EXCLUDED_PATTERNS[data_dir], + ) + data["seed"].extend(snippets) + for key in ["repo", "star", "id", "raw_index", "data_dir"]: + data[key].extend([locals()[key]] * len(snippets)) + except UnicodeError: + pass + return data + + +# def uniform_partition(n: int, k: int) -> list[int]: +# """Partition n into k non-negative integers. Stars and bars method. +# x1 + x2 + ... + xk = n; xi >= 0. Can be transformed to positive case: +# y1 + y2 + ... + yk = n - k; yi = xi + 1 > 0""" +# assert n >= 0, "n should be >= 0" +# assert k > 0, " should be > 0" +# random_numbers = [random.randint(0, n) for _ in range(k - 1)] +# values = [0] + sorted(random_numbers) + [n] +# intervals = [values[i + 1] - values[i] for i in range(len(values) - 1)] +# assert sum(intervals) == n +# assert len(intervals) == k +# return intervals + + +# def uniform_partition_positive(n: int, k: int) -> list[int]: +# return [x + 1 for x in uniform_partition(n - k, k)] + + +# def is_en(content: str, seed: int) -> bool: +# import langdetect + +# langdetect.DetectorFactory.seed = seed +# try: +# return langdetect.detect(content) == "en" +# except langdetect.lang_detect_exception.LangDetectException: +# return False + + +# def place_blocks(n: int, sizes: list[int]) -> list[int]: +# """Randomly place k blocks of sizes `sizes` in a line of length n. 
Return the starting positions.""" +# assert n >= 0, "n should be >= 0" +# k = len(sizes) +# assert k > 0, "k should be > 0" +# assert sum(sizes) <= n, "sum(sizes) should be <= n" +# if k == 1: +# return [random.randint(0, n - sizes[0])] +# all_but_last_pos = place_blocks(n - sizes[-1], sizes[:-1]) +# last_pos = random.randint(all_but_last_pos[-1] + sizes[-2], n - sizes[-1]) +# result = all_but_last_pos + [last_pos] +# assert len(result) == k +# for i in range(k - 1): +# assert result[i] + sizes[i] <= result[i + 1] +# return result + + +def sanitize_document(document: str) -> str | None: + """Sanitize the document by removing the first line if it's a placeholder.""" + if ( + document.startswith("") + or document.startswith("") + or document.startswith("") + ): + # remove the first line + newline_index = document.find("\n") + if newline_index == -1: + return None + document = document[newline_index + 1 :] + return document + + +def extract_fragment(args: Args, document: str) -> str | None: + if args.std_lib_only: + if not check_std_libs_only(document): + return None + else: + return document + if args.text_wrap is not None: + document = textwrap.fill( + document, + width=args.text_wrap, + replace_whitespace=False, + expand_tabs=False, + break_on_hyphens=False, + drop_whitespace=False, + break_long_words=False, + ) + # if args.data_dir == "markdown" and not is_en(document, args.seed): + # return None + document = document.replace("\r", "") + document = re.sub(r"\n\n+", "\n\n", document) + lines = document.splitlines(keepends=True) + + # special rules + # if args.data_dir == "jupyter-scripts-dedup-filtered": + # lines = [ + # line + # for line in lines + # if "jupyter" not in line.lower() and "jupytext" not in line.lower() + # ] + # elif args.data_dir == "markdown": + # lines = [ + # line + # for line in lines + # if "http:" not in line and "https:" not in line and "www." 
not in line + # ] + + # lines = [line for line in lines if line.strip() != ""] + + # if len(lines) < args.min_lines or len(lines) == 0: + if len(lines) < args.min_doc_lines or len(lines) > args.max_doc_lines: + return None + # avg chars + if len(document) > args.max_avg_chars_per_line * args.max_doc_lines: + return None + max_lines = min(args.max_lines, len(lines)) + assert args.max_lines >= args.min_lines + n_lines_to_consider = random.randint(args.min_lines, max_lines) + # max_fragments = min(n_lines_to_consider, args.max_fragments) + # n_fragments = random.randint(1, max_fragments) + # fragment_sizes = uniform_partition_positive(n_lines_to_consider, n_fragments) + # fragment_indices = place_blocks(len(lines), fragment_sizes) + # fragments = [ + # "".join(lines[i : i + size]) + # for i, size in zip(fragment_indices, fragment_sizes) + # ] + start_index = random.randint(0, len(lines) - n_lines_to_consider) + # random.shuffle(fragments) + content = "".join(lines[start_index : start_index + n_lines_to_consider]) + content = textwrap.dedent(content.replace("\t", " ")) + return content + + +import sys + + +def is_std_lib(name): + """Check if a module is a standard library.""" + return name in {*sys.builtin_module_names, *sys.stdlib_module_names} + + +def check_std_libs_only(code): + """Check if all imported libraries in the given code are standard libraries.""" + lines = code.split("\n") + for line in lines: + if line.startswith("import ") or line.startswith("from "): + parts = line.split() + if len(parts) < 2: + # special case + return False + if line.startswith("import "): + module_name = parts[1].split(".")[0] # Get the base module name + else: # from ... import ... + module_name = parts[1] + + if not is_std_lib(module_name): + return False + return True + + +def main(): + args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) + assert len(args.data_dirs) == len(args.data_mix_weights) + sum_weights = sum(args.data_mix_weights) + data_mix_ratios = [w / sum_weights for w in args.data_mix_weights] + random.seed(args.seed) + raw_datasets: list[Dataset] = [] + num_proc = args.n_cores + # num_proc = 1 + for data_dir, ratio in zip(args.data_dirs, data_mix_ratios): + max_considered_data = ( + None + if args.max_considered_data is None + else int(args.max_considered_data * ratio) + ) + print(f"Loading {data_dir} with max_considered_data={max_considered_data}") + split = ( + f"train[:{max_considered_data}]" + if max_considered_data is not None + else "train" + ) + try: + kwargs = dict( + data_dir=data_dir, + split=split, + data_files=args.data_files, + num_proc=num_proc, + ignore_verifications=True, + ) + sub_dataset = load_dataset(args.dataset_name, **kwargs) + except ValueError: + print( + f"Failed to load {data_dir} with split=train[:{max_considered_data}]. 
Trying split=train" + ) + kwargs["split"] = "train" + sub_dataset = load_dataset(args.dataset_name, **kwargs) + if not args.collect_function: + sub_dataset = sub_dataset.map( + function=chunk_content, + fn_kwargs=dict(args=args), + with_indices=True, + batched=True, + batch_size=args.chunk_size, + num_proc=num_proc, + remove_columns=["content"], + ) + raw_datasets.append(sub_dataset) + # map_fn = get_map_dataset(args) + datasets: list[Dataset] = [] + for data_dir, sub_dataset in zip(args.data_dirs, raw_datasets): + sub_dataset = sub_dataset.map( + function=map_dataset, + fn_kwargs=dict(args=args, data_dir=data_dir), + with_indices=True, + batched=True, + num_proc=num_proc, + batch_size=args.chunk_size, + remove_columns=sub_dataset.column_names, + load_from_cache_file=False, + ) + datasets.append(sub_dataset) + dataset = concatenate_datasets(datasets) + dataset = dataset.shuffle(seed=args.seed) + + # Every run should produce the same data as long as the default params are not changed + data_fingerprint = args.fingerprint() + timestamp = star_align.utils.timestamp() + tag = "" if args.tag == "" else f"-{args.tag}" + path = Path(f"data-seed{tag}-{data_fingerprint}-{timestamp}.jsonl") + assert not path.exists() + with path.open("w") as f_out: + print("Saving to", path) + + n_success = 0 + all_seed_texts = set[str]() + + def get_seed_text(seed: str) -> str: + return "".join(seed.split()) + + pbar = tqdm(total=min(args.max_seeds_to_collect, len(dataset))) + for example in dataset: + if n_success >= args.max_seeds_to_collect: + break + if example["seed"] is None: + continue + seed_text = get_seed_text(example["seed"]) + # remove those with only symbols + if all(not c.isalpha() for c in seed_text): + # print("[filter(symbols Only)]", example["seed"], sep="\n") + continue + if seed_text in all_seed_texts: + # print("[filter(duplicate)]", example["seed"], sep="\n") + continue + all_seed_texts.add(seed_text) + data = example + # print("[Seed]", example["seed"], sep="\n") + f_out.write(json.dumps(data) + "\n") + n_success += 1 + pbar.update(1) + print("Success:", n_success) + + +if __name__ == "__main__": + main() diff --git a/src/star_align/decontamination/benchmark_data.py b/src/star_align/decontamination/benchmark_data.py new file mode 100644 index 0000000..15c6762 --- /dev/null +++ b/src/star_align/decontamination/benchmark_data.py @@ -0,0 +1,189 @@ +"""Migrated from: https://github.com/bigcode-project/bigcode-dataset. 
License: Apache 2.0""" + +"""data to filter out of the dataset""" +import itertools +import json +import os +from pathlib import Path + +from datasets import load_dataset + +# HumanEval solutions that are considered simple/generic enough to be kept in the training dataset +HUMAN_EVAL_STRINGS_OK = [ + "return x + y", + "return len(string)", + "return n**2", + "return " ".join(strings)", +] + + +def extract_ds_1000_prompt(prompt: str): + if "SOLUTION START" in prompt: + assert prompt.count("SOLUTION START") == 1 + return prompt.split("SOLUTION START")[0] + elif "BEGIN SOLUTION" in prompt: + assert prompt.count("BEGIN SOLUTION") == 1 + return prompt.split("BEGIN SOLUTION")[0] + else: + raise ValueError() + + +def load_ds_1000(): + DS1000_PATH_NAME = os.getenv("DS1000_PATH", None) + assert ( + DS1000_PATH_NAME is not None + ), "Please set the environment variable DS1000_PATH to the path of `ds1000_data`" + DS1000_PATH = Path(DS1000_PATH_NAME) # type: ignore + data: dict = {} + for prompt_file in DS1000_PATH.glob("*/Insertion/q*/prompt.txt"): + with open(prompt_file) as f: + data[extract_ds_1000_prompt(f.read())] = prompt_file.as_posix() + return data + + +def load_mbpp(): + MBPP_PATH_NAME = os.getenv("MBPP_PATH", None) + assert ( + MBPP_PATH_NAME is not None + ), "Please set the environment variable MBPP_PATH to the path of `mbpp.jsonl`" + MBPP_PATH = Path(MBPP_PATH_NAME) + TEST_IDS = list(range(11, 511)) + data = [] + with open(MBPP_PATH) as f: + for line in f: + data.append(json.loads(line)) + + data = [sample for sample in data if sample["task_id"] in TEST_IDS] + + assert len(data) == 500 + + # Checksum / version issues here + # dataset = load_dataset("mbpp", split="test") + return data + + +def mbpp_docstrings(): + data = load_mbpp() + return {sample["text"]: str(sample["task_id"]) for sample in data} + + +def mbpp_solutions(): + data = load_mbpp() + return {sample["code"]: str(sample["task_id"]) for sample in data} + + +def extract_docstring(prompt: str) -> str: + if '"""' in prompt: + if prompt.count('"""') == 2: + return prompt.split('"""')[1].strip() + elif prompt.count('"""') == 4: + return prompt.split('"""')[3].strip() + else: + raise ValueError() + elif "'''" in prompt: + assert prompt.count("'''") == 2 + return prompt.split("'''")[1].strip() + else: + raise ValueError() + + +def human_eval_docstrings(): + ds = load_dataset("openai_humaneval", split="test") + docstrings = {extract_docstring(v["prompt"]): str(v["task_id"]) for v in ds} + return docstrings + + +def apps_solutions(): + """ + Solutions column contains a list of strings + """ + ds = load_dataset("codeparrot/apps", split="test") + solutions = [sample["solutions"] for sample in ds if len(sample["solutions"]) > 0] + res = itertools.chain.from_iterable(json.loads(sample) for sample in solutions) + return list(res) + + +def multipl_e_docstrings(): + languages = [ + "cpp", + "cs", + "d", + "go", + "java", + "jl", + "js", + "lua", + "php", + "pl", + "py", + "r", + "rb", + "rkt", + "rs", + "scala", + "sh", + "swift", + "ts", + ] + # languages = ["py", "java", "js"] + src_datas = ["humaneval", "mbpp"] + variations = ["", "-remove"] + data = [] + for lang in languages: + for src_data in src_datas: + for variation in variations: + if src_data == "mbpp" and variation == "-remove": + continue + ds = load_dataset( + "nuprl/MultiPL-E", f"{src_data}-{lang}{variation}", split="test" + ) + data += [sample["prompt"].strip() for sample in ds] + return data + + +def load_dataset_column(dataset: str, column: str, split: str, name=None): + ds 
= load_dataset(dataset, split=split, name=name) + # res = [sample[column].strip() for sample in ds] + # Only return non-empty strings + return { + sample_col_stripped: str( + sample["task_id"] if "task_id" in sample else f"{dataset}/{idx}" + ) + for idx, sample in enumerate(ds) + if len(sample_col_stripped := sample[column].strip()) > 0 + } + + +LAZY_FILTER_OUT = { + "human_eval_docstrings": lambda: human_eval_docstrings(), + "human_eval_solutions": lambda: { + s: v + for s, v in load_dataset_column( + "openai_humaneval", "canonical_solution", "test" + ).items() + if s not in HUMAN_EVAL_STRINGS_OK + }, + # "apps_docstrings": lambda: load_dataset_column( + # "codeparrot/apps", "question", "test" + # ), + # 115212 examples to filter-out in apps-solutions, which would take way too much time without any hashing trick + # "apps_solutions": apps_solutions(), + # MultiPL-E samples are from HumanEval and MBPP: we are already looking for them + # "multipl-e_docstrings": multipl_e_docstrings(), + # There is no solution provided with multipl-e + "gsm8k_questions": lambda: load_dataset_column("gsm8k", "question", "test", "main"), + "ds_1000_prompts": lambda: load_ds_1000(), + "mbpp_docstrings": lambda: mbpp_docstrings(), + "mbpp_solutions": lambda: mbpp_solutions(), +} + +IGNORED = os.getenv("IGNORED", "").split(":") +print("Ignoring:", IGNORED) +for ignored in IGNORED: + if ignored != "" and ignored in LAZY_FILTER_OUT: + del LAZY_FILTER_OUT[ignored] +FILTER_OUT = {k: v() for k, v in LAZY_FILTER_OUT.items()} + + +for benchmark, values in FILTER_OUT.items(): + print(f"num strings from {benchmark}: {len(values)}") diff --git a/src/star_align/decontamination/find_substrings.py b/src/star_align/decontamination/find_substrings.py new file mode 100644 index 0000000..2c1f2f6 --- /dev/null +++ b/src/star_align/decontamination/find_substrings.py @@ -0,0 +1,368 @@ +# type: ignore +"""Migrated from: https://github.com/bigcode-project/bigcode-dataset. License: Apache 2.0""" + +import argparse +import json +import os +import shutil +from copy import deepcopy +from glob import glob +from pathlib import Path + +from datasets import load_dataset + +from star_align.utils import write_jsonl + +from .benchmark_data import FILTER_OUT +from .utils import add_dict, shard_dataset + +SHARD_SIZE = 1000 << 20 # 1GB +LANGUAGE_COL = "lang" +# LANGUAGES = ["Python", "Java", "JavaScript"] + + +def dump_benchmarks(file_path: str): + """ + Dump the dictionary of benchmark samples that are filtered out + """ + with open(file_path, "w") as f: + json.dump(FILTER_OUT, f, indent=2) + + +def filter_reason_to_benchmark_name(filter_reason: str): + assert filter_reason.endswith("_match") + return filter_reason[:-6] + + +def benchmark_name_to_filter_reason(benchmark_name: str): + return f"{benchmark_name}_match" + + +def update_benchmark_dict( + filter_out: dict, benchmark_cache: str, excluded_data_cache: str +): + """ + Iterates on current benchmark-samples. 
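For reference, the decontamination step boils down to a case-insensitive substring scan: `FILTER_OUT` maps each benchmark to the strings gathered above, and any training sample whose text contains one of them is dropped and tagged with a `<benchmark>_match` reason. A minimal, self-contained sketch of that check (the benchmark strings and the sample below are illustrative, not taken from the real dumps):

```python
# Minimal sketch of the substring-based decontamination check (illustrative data).
FILTER_OUT_EXAMPLE = {
    "human_eval_docstrings": ["Return the longest of the given strings."],
    "mbpp_docstrings": ["Write a function to find the shared elements from the given two lists."],
}

def should_keep(sample_text: str, filter_out: dict[str, list[str]]) -> tuple[bool, str | None]:
    """Return (keep, filter_reason); the reason names the matched benchmark."""
    haystack = sample_text.lower()
    for benchmark, strings in filter_out.items():
        for s in strings:
            if s.lower() in haystack:
                return False, f"{benchmark}_match"
    return True, None

sample = "# Write a function to find the shared elements from the given two lists.\ndef shared(a, b): ..."
print(should_keep(sample, FILTER_OUT_EXAMPLE))  # (False, 'mbpp_docstrings_match')
```

A plain nested scan like this is why very large string collections (e.g. the APPS solutions noted above) are left out: without a hashing trick the quadratic search would be far too slow.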
If a sample is found in the cached benchmark-samples, it is removed (it does not need to be searched), + and the corresponding data-samples from the cache are added to `exclude_data` + + Returns: + - `updated`: an updated benchmark dict where samples from the cache are removed (they do not need to be searched anymore) + - `exclude_data`: a list of files to remove from the dataset + """ + updated = deepcopy(filter_out) + exclude_data = [] + with open(benchmark_cache) as f: + benchmark_cache = json.load(f) + with open(excluded_data_cache) as f: + excluded_data_cache = json.load(f) + + for bench, samples in filter_out.items(): + for bench_sample in samples: + # Benchmark-sample was found in cache + if bench in benchmark_cache and bench_sample in benchmark_cache[bench]: + # No need to search for this sample in the dataset + updated[bench].remove(bench_sample) + # Corresponding data-samples will be excluded from the dataset. + exclude_data += [ + data_sample + for data_sample in excluded_data_cache + if data_sample["filter_reason"] + == benchmark_name_to_filter_reason(bench) + and data_sample["matched_substring"] == bench_sample + ] + + print("After loading cache, will search for:") + for benchmark, values in updated.items(): + print(f" num strings from {benchmark}: {len(values)}") + # Remove empty benchmarks + updated = {key: value for key, value in updated.items() if len(value) > 0} + return updated, exclude_data + + +def find_substrings(data, columns, filter_out, return_matched=False): + """ + filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be + filtered-out. + Return True, None if the file should be included in the dataset. + Otherwise return False and some metadata about the file excluded + """ + content = "\n\n".join([data[col].lower() for col in columns]) + # For each substring, try to find it in the file (case insensitive) + for benchmark, substrings in filter_out.items(): + for substring in substrings: + if substring.lower() in content: + if return_matched: + return False, benchmark_name_to_filter_reason(benchmark), substring + else: + return False, benchmark_name_to_filter_reason(benchmark) + + # Return True, None if none of the substrings was found + if return_matched: + return True, None, None + else: + return True, None + + +def aggregate_meta(tmp_meta_dir: str): + res = {} + for file in glob(f"{tmp_meta_dir}/*-meta.json"): + with open(file, "r") as f: + meta = json.load(f) + add_dict(res, meta) + return res + + +def concatenate_meta(tmp_meta_dir: str): + res = [] + for file in glob(f"{tmp_meta_dir}/*-excluded-data.json"): + with open(file, "r") as f: + meta = json.load(f) + res += meta + return res + + +class Meta: + def __init__(self) -> None: + self.meta_dict = dict() + + def update(self, lang: str, filter_reason: str): + if lang not in self.meta_dict: + self.meta_dict[lang] = {} + if filter_reason not in self.meta_dict[lang]: + self.meta_dict[lang][filter_reason] = 0 + self.meta_dict[lang][filter_reason] += 1 + + +class SubstringFilterer(object): + def __init__( + self, + output_dir: str, + output_file: str, + cached_decontamination_dir: str, + split_languages: bool, + cache_retrieval_key: str, + columns: list[str], + tmp_meta_dir=None, + data_dir=None, + ) -> None: + self.output_dir = output_dir + self.output_file = output_file + self.split_languages = split_languages + self.cache_retrieval_key = cache_retrieval_key + self.columns = columns + self.tmp_meta_dir = ( + tmp_meta_dir if tmp_meta_dir is not None else 
f"{output_dir}/tmp/meta" + ) + self.data_dir = data_dir if data_dir is not None else f"{output_dir}/data" + os.makedirs(self.tmp_meta_dir, exist_ok=True) + os.makedirs(self.data_dir, exist_ok=True) + # Save benchmark data + self.excluded_data_cache = os.path.join(self.output_dir, "excluded-data.json") + self.benchmarks_cache = os.path.join(output_dir, "benchmarks.json") + dump_benchmarks(self.benchmarks_cache) + + if cached_decontamination_dir is not None: + # Load cache + self.filter_out, self.exclude_data = update_benchmark_dict( + FILTER_OUT, + os.path.join(cached_decontamination_dir, "benchmarks.json"), + os.path.join(cached_decontamination_dir, "excluded-data.json"), + ) + # All hashes should be unique + hash_list = [ + data_sample["data"][self.cache_retrieval_key] + for data_sample in self.exclude_data + ] + assert len(hash_list) == len(set(hash_list)) + # dict: retrieval-key (hash/content) -> data-sample + self.exclude_data_index = { + data_sample["data"][self.cache_retrieval_key]: data_sample + for data_sample in self.exclude_data + } + self.use_cached_decontamination = True + else: + self.filter_out = FILTER_OUT + self.exclude_data = None + self.exclude_data_index = {} + self.use_cached_decontamination = False + + def _filter_file(self, sample): + should_include, filter_reason, matched_substring = True, None, None + if self.use_cached_decontamination: + # According to cache, this data sample should be excluded + if sample[self.cache_retrieval_key] in self.exclude_data_index: + should_include = False + filter_reason = self.exclude_data_index[ + sample[self.cache_retrieval_key] + ]["filter_reason"] + matched_substring = self.exclude_data_index[ + sample[self.cache_retrieval_key] + ]["matched_substring"] + # If sample has passed the cache, check the other substrings + if should_include: + should_include, filter_reason, matched_substring = find_substrings( + sample, self.columns, self.filter_out, return_matched=True + ) + return should_include, filter_reason, matched_substring + + def _filter(self, batch: dict, idx): + meta = Meta() + excluded_data = [] + features = batch.keys() + res = {k: [] for k in features} + for sample in zip(*[batch[k] for k in features]): + sample = {k: v for k, v in zip(features, sample)} + should_include, filter_reason, matched_substring = self._filter_file(sample) + if not should_include: + meta.update(sample.get(LANGUAGE_COL, "unknown"), filter_reason) + excluded_data.append( + { + "data": sample, + "filter_reason": filter_reason, + "matched_substring": matched_substring, + } + ) + else: + # Add to output + for k in features: + res[k].append(sample[k]) + + # Record Meta + with open( + os.path.join(self.tmp_meta_dir, f"{idx[0]}-{idx[-1]}-meta.json"), "w" + ) as f: + json.dump(meta.meta_dict, f) + with open( + os.path.join(self.tmp_meta_dir, f"{idx[0]}-{idx[-1]}-excluded-data.json"), + "w", + ) as f: + json.dump(excluded_data, f, indent=2) + return res + + def filter_dataset(self, ds, num_proc, batch_size): + filtered = ds.map( + self._filter, + batched=True, + batch_size=batch_size, + with_indices=True, + num_proc=num_proc, + load_from_cache_file=False, + ) + print("Number of samples in the new dataset: ", len(filtered)) + return filtered + + def finalize(self): + # Dump meta + meta = aggregate_meta(self.tmp_meta_dir) + print(meta) + with open(os.path.join(self.output_dir, "meta.json"), "w") as f: + json.dump(meta, f, indent=2) + # Dump excluded-data.json + meta = concatenate_meta(self.tmp_meta_dir) + print("Number of excluded examples: ", len(meta)) + with 
open(self.excluded_data_cache, "w") as f: + json.dump(meta, f, indent=2) + # delete temporary meta data + shutil.rmtree(self.tmp_meta_dir) + + # def save(self, filtered, num_proc): + # # Save shards + # if self.split_languages: + # for lang in LANGUAGES: + # print(f"Sharding subset: {lang}") + # target_dir = os.path.join(self.data_dir, lang.lower()) + # os.makedirs(target_dir, exist_ok=True) + # subset = filtered.filter(lambda example: example[LANGUAGE_COL] == lang, num_proc=num_proc) + # shard_dataset(subset, SHARD_SIZE, target_dir, num_proc=16) + # else: + # shard_dataset(filtered, SHARD_SIZE, self.data_dir, num_proc=16) + + def run(self, dataset, num_proc, batch_size): + filtered = self.filter_dataset(dataset, num_proc, batch_size) + write_jsonl(Path(self.output_file), filtered) + # Finalize meta-data + self.finalize() + # Save filtered dataset. + # NOTE: we save to jsonl so this is not needed + # self.save(filtered, num_proc) + return filtered + + +def arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset_name", + default="json", + type=str, + help="Name or path of the HF dataset to decontaminate", + ) + parser.add_argument("--data_files", nargs="+", default=None, help="Data files") + parser.add_argument( + "--columns", + nargs="+", + required=True, + help="Columns to form the text to search for", + ) + parser.add_argument( + "--output_file", required=True, type=str, help="Path to save output jsonl data" + ) + parser.add_argument( + "--output_dir", + required=True, + type=str, + help="Path to save output data and metadata", + ) + parser.add_argument("--num_proc", type=int, default=200, help="Number of processes") + parser.add_argument( + "--batch_size", + type=int, + default=10000, + help="Size of batches passed to Dataset.map", + ) + parser.add_argument( + "--cached_decontamination_dir", + type=str, + default=None, + help="Directory containing a `benchmarks.json` and `excluded_data.json` files from a previous decontamination run." + "Will use this data to avoid searching again for strings that were previously decontaminated." + "It's up to the user to ensure that the dataset being decontaminated is a subset of the one from the cached decontamination run" + "(Otherwise not all the benchmark samples will be checked against new data samples)", + ) + parser.add_argument( + "--cache_retrieval_key", + type=str, + default="hexsha", + help="Key used to retrieve examples from the cache. Ideally `hexsha`. 
Otherwise, another unique feature in case the hash is not present, like `content`)", + ) + parser.add_argument( + "--split_languages", + action="store_true", + help="If True, will create one subfolder per language for the output dataset.", + ) + return parser.parse_args() + + +def main(): + args = arguments() + + filterer = SubstringFilterer( + output_dir=args.output_dir, + output_file=args.output_file, + columns=args.columns, + cached_decontamination_dir=args.cached_decontamination_dir, + split_languages=args.split_languages, + cache_retrieval_key=args.cache_retrieval_key, + ) + + ds = load_dataset( + args.dataset_name, + split="train", + data_files=args.data_files, + # chunksize=40 << 20 + ) + + filterer.run(ds, args.num_proc, args.batch_size) + + +if __name__ == "__main__": + main() diff --git a/src/star_align/decontamination/utils.py b/src/star_align/decontamination/utils.py new file mode 100644 index 0000000..b1adbe5 --- /dev/null +++ b/src/star_align/decontamination/utils.py @@ -0,0 +1,62 @@ +"""Migrated from: https://github.com/bigcode-project/bigcode-dataset. License: Apache 2.0""" + +import time +from multiprocessing import Pool + +from tqdm import tqdm + + +def save_shard(shard_tuple): + """Save shard""" + filename, shard = shard_tuple + # use to_json instead to save as json file + shard.to_parquet(filename) + + +def shard_dataset(ds, shard_size, output_dir, num_proc): + if ds._indices is not None: + dataset_nbytes = ds.data.nbytes * len(ds._indices) / len(ds.data) + else: + dataset_nbytes = ds.data.nbytes + num_shards = int(dataset_nbytes / shard_size) + 1 + print(f"Number of shards: {num_shards}") + + print("sharding the dataset") + t_start = time.time() + shards = ( + ds.shard(num_shards=num_shards, index=i, contiguous=True) + for i in range(num_shards) + ) + # use f"{OUT_PATH}/data/train-{index:05d}-of-{num_shards:05d}.json" instead for json files + filenames = ( + f"{output_dir}/train-{index:05d}-of-{num_shards:05d}.parquet" + for index in range(num_shards) + ) + + with Pool(num_proc) as p: + list( + tqdm( + p.imap_unordered(save_shard, zip(filenames, shards), chunksize=4), + total=num_shards, + ) + ) + print(f"Time to save dataset: {time.time()-t_start:.2f}") + + +def add_dict(dict1: dict, dict2: dict) -> None: + """ + Add the values of dict2 to dict1. All values must be int, float or dictionaries that also verify this condition. 
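A standalone sketch of the recursive counter merge that `add_dict` performs when the per-shard metadata files are aggregated (the keys and counts below are made up):

```python
# Illustrative only: merging nested dictionaries of counters in place, add_dict-style.
def merge_counts(dst: dict, src: dict) -> None:
    for key, value in src.items():
        if isinstance(value, (int, float)):
            dst[key] = dst.get(key, 0) + value
        elif isinstance(value, dict):
            merge_counts(dst.setdefault(key, {}), value)
        else:
            raise ValueError(f"Unsupported value for {key!r}: {value!r}")

total: dict = {}
merge_counts(total, {"Python": {"human_eval_docstrings_match": 2}})
merge_counts(total, {"Python": {"human_eval_docstrings_match": 1, "mbpp_solutions_match": 4}})
print(total)  # {'Python': {'human_eval_docstrings_match': 3, 'mbpp_solutions_match': 4}}
```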
+ Will modify dict1 and return None + """ + for key, value in dict2.items(): + if isinstance(value, (int, float)): + if key not in dict1: + dict1[key] = 0 + dict1[key] += value + elif isinstance(value, dict): + if key not in dict1: + dict1[key] = {} + assert isinstance(dict1[key], dict) + add_dict(dict1[key], value) + else: + raise ValueError(f"Invalid type for key/value {key}: {value}") diff --git a/src/star_align/execution_filter.py b/src/star_align/execution_filter.py index 77d4bc1..28be9a8 100644 --- a/src/star_align/execution_filter.py +++ b/src/star_align/execution_filter.py @@ -1,24 +1,59 @@ import ast import json import os -import re import shutil import sys from concurrent.futures import ProcessPoolExecutor, as_completed from multiprocessing import Process, cpu_count - from evalplus.eval.utils import ( create_tempdir, reliability_guard, swallow_io, time_limit, ) -from fire import Fire from tqdm.auto import tqdm -from star_align.utils import chunked +from datasets import load_dataset +from star_align.utils import chunked, find_code_blocks +from transformers import HfArgumentParser +from dataclasses import dataclass, field +from typing import cast + + +_magic_splitter_ = "### -- what do you think? -- ###" + + +def make_python_membound_code_prefix(limit_mb): + maximum_memory_bytes = limit_mb * 1024 * 1024 + return f"""\ +import resource +import platform -# dependencies: evalplus fire tqdm +resource.setrlimit( + resource.RLIMIT_AS, ({maximum_memory_bytes}, {maximum_memory_bytes}) +) +resource.setrlimit( + resource.RLIMIT_DATA, ({maximum_memory_bytes}, {maximum_memory_bytes}) +) +if not platform.uname().system == "Darwin": + resource.setrlimit( + resource.RLIMIT_STACK, ({maximum_memory_bytes}, {maximum_memory_bytes}) + ) +{_magic_splitter_} +""" + + +@dataclass(frozen=True) +class Args: + response_paths: list[str] + result_path: str + save_request_errors: bool = False + shuffle: bool = field(default=True) + cache_paths: list[str] = field(default_factory=list) + load_pass_only_cache: bool = field(default=False) + max_batched_tasks: int = 10000 + max_workers: int = cpu_count() + container_server: str | None = None def suppress_output(func): @@ -66,23 +101,26 @@ def _run(code) -> None: os.getcwd = getcwd -# from code_exec_server.code_exec_reqs import exec_test -def containerized_run(item): - from code_exec_server.code_exec_reqs import exec_test +def containerized_run(item, limit_mb=4 * 1024): + from star_align.code_exec_server.code_exec_reqs import exec_test idx, result, code, srv = item - passed, _ = exec_test(srv, code, "", timeout=10) - return (idx, result) if passed else None + membound_code = make_python_membound_code_prefix(limit_mb) + code + passed, output = exec_test( + srv, membound_code, "", timeout=10, timeout_on_client=True + ) + return (idx, result, code, passed, output) def fork_run(item): - idx, result, code, _ = item + idx, response, code, _ = item sys.stdout = open(os.devnull, "w") sys.stderr = sys.stdout p = Process(target=_run, args=(code,)) p.start() p.join(timeout=10) - return (idx, result) if p.exitcode == 0 else None + passed = p.exitcode == 0 + return (idx, response, code, passed, "NOT SUPPORTED") def is_compilable(code): @@ -93,114 +131,162 @@ def is_compilable(code): return False -def extract_code(response): - pattern = r"^```python\s*\n(.*?)(?=^```)" - result = re.findall(pattern, response, re.DOTALL | re.MULTILINE) - return "\n".join([x for x in result if is_compilable(x)]) - - -# 
/scratch/sc2-instruct/data-ossx-del2-fewshot-mpt-response-temp0-s_i-1shot-temp0-i_r-8754b-0-20240321_172151.jsonl -def main( - response_path: str, - result_path: str, - # NOTE: the higher the faster, but less reliable. 5000 is a good default. - max_batched_tasks: int = cpu_count(), - max_workers: int = cpu_count(), - cache_path: str | None = None, - container_server=None, -): - # load jsonl - with open(response_path, "r") as f: - raw_data = [json.loads(line) for line in f if line.strip()] - if cache_path is not None: - with open(cache_path, "r") as f: - cached_data = [json.loads(line) for line in f if line.strip()] - # instruction -> set[response] - hit_code = set[str]() - for item in tqdm(cached_data): - code = extract_code(item["response"]) - hit_code.add(code) - - uncompilable = 0 - all_tasks = [] - - print("Container server:", container_server) - - for idx, item in enumerate(tqdm(raw_data)): - # passing_results = [] - if "parsing_result" not in item: - code = extract_code(item["response"]) - if not code: - uncompilable += 1 - continue - all_tasks.append((idx, item, code, container_server)) - else: - for result in item["parsing_result"]: - code = extract_code(result["response"]) - if not code: - uncompilable += 1 - continue - all_tasks.append((idx, result, code, container_server)) - # passing_results.append((result, code)) - - # Split cached/un-cached data - active_tasks = [] - cached_tasks = [] - for task in tqdm(all_tasks, desc="Preprocessing: flattening tasks"): - _, _, code, _ = task - if cache_path is not None and code in hit_code: - cached_tasks.append(task) - else: - active_tasks.append(task) - - with open(result_path, "w") as f: - for idx, result, _, _ in cached_tasks: - newdata = { - k: v - for k, v in raw_data[idx].items() - if k not in ["response", "parsing_result"] +def extract_code(response: str) -> str: + def sanitize_codeblock(code: str) -> str: + if "input" not in code: + return code.strip() + # Only remove the `if __name__..` when `input` is present because + # it will block the code execution. + key = 'if __name__ == "__main__":' + key_alt = "if __name__ == '__main__':" + index = code.find(key) + if index == -1: + index = code.find(key_alt) + if index == -1: + return code.strip() + assert index != -1 + code = code[:index].strip() + return code + + code_blocks = list(map(sanitize_codeblock, find_code_blocks(response))) + return "\n\n".join(code_blocks) + + +def form_new_data( + item: dict, + response: str, + extracted_code: str, + pass_execution: bool, + output: str, +) -> dict: + newdata = {k: v for k, v in item.items() if k not in ["response", "parsing_result"]} + newdata["response"] = response + newdata["extracted_code"] = extracted_code + newdata["pass"] = pass_execution + newdata["output"] = output + return newdata + + +def main(): + args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) + if args.container_server is None: + option = input( + "WARNING: container_server is not set. You will run the code locally, which can lead to unexpected side effects. Continue? (y/n): " + ).strip() + if option.lower() != "y": + return + + if os.path.exists(args.result_path): + option = input( + f"WARNING: {args.result_path} already exists. Overwrite? 
(y/n): " + ).strip() + if option.lower() != "y": + return + + cleanup_command = os.getenv("CLEANUP_COMMAND", None) + if cleanup_command is not None: + print(f"NOTE: the cleanup command is set to:") + print(cleanup_command) + + raw_data = load_dataset("json", data_files=args.response_paths, split="train") + if args.shuffle: + raw_data = raw_data.shuffle() + if len(args.cache_paths) > 0: + cached_data = load_dataset("json", data_files=args.cache_paths, split="train") + if args.load_pass_only_cache: + cached_dict: dict[str, dict] = { + item["extracted_code"]: item for item in cached_data if item["pass"] } - newdata["response"] = result["response"] - f.write(json.dumps(newdata) + "\n") - - print(f"Active tasks: {len(active_tasks)}") - print(f"Cached tasks: {len(cached_tasks)}") - - run_func = containerized_run if container_server else fork_run + else: + cached_dict = {item["extracted_code"]: item for item in cached_data} + else: + cached_dict = {} - nfails = 0 - tasks_chunks = chunked(active_tasks, max_batched_tasks) - with open(result_path, "a") as f: - with ProcessPoolExecutor(max_workers=max_workers) as executor: - for chunked_tasks in tqdm(tasks_chunks): + all_tasks: list[tuple[int, str, str, str | None]] = [] + eval_results: list[dict] = [] + for idx, item in enumerate(tqdm(raw_data, desc="Preprocessing: extracting code")): + # passing_results = [] + if "parsing_result" not in item: + item["parsing_result"] = [dict(response=item["response"])] + for result in item["parsing_result"]: + response = result["response"] + code = extract_code(response) + if (hit_item := cached_dict.get(code, None)) is not None: + assert code == hit_item["extracted_code"] + new_data = form_new_data( + item=item, + response=response, + extracted_code=code, + pass_execution=hit_item["pass"], + output=hit_item["output"], + ) + eval_results.append(new_data) + else: + all_tasks.append((idx, response, code, args.container_server)) + + def pass_rate_str(passed: int, total: int, tag: str = "") -> str: + percentage = f"{passed/total * 100:.2f}%" if total > 0 else "N/A" + ratio = f"{passed}/{total}" + tag = f"{tag} " if len(tag) > 0 else "" + return f"{tag}Passed: {ratio} ({percentage})" + + n_cached_passed = sum(item["pass"] for item in eval_results) + n_cached_total = len(eval_results) + + print(f"Cached: {len(eval_results)}, Active: {len(all_tasks)}") + print(pass_rate_str(n_cached_passed, n_cached_total, "Cached")) + + run_func = containerized_run if args.container_server else fork_run + tasks_chunks = list(chunked(all_tasks, args.max_batched_tasks)) + n_processed = 0 + n_passed = 0 + with open(args.result_path, "w") as f: + for cached_result in eval_results: + f.write(json.dumps(cached_result) + "\n") + with ProcessPoolExecutor(max_workers=args.max_workers) as executor: + pbar = tqdm(tasks_chunks) + for chunked_tasks in pbar: futures = [executor.submit(run_func, task) for task in chunked_tasks] # NOTE: futures do not return in the same order as before - for future in tqdm(as_completed(futures), total=len(futures), leave=False): + pbar_inner = tqdm( + as_completed(futures), + total=len(futures), + leave=False, + ) + n_passed_inner = 0 + for n_processed_inner, future in enumerate(pbar_inner, start=1): + n_processed += 1 try: future_result = future.result() - if future_result is None: - nfails += 1 - continue - idx, result = future_result - newdata = { - k: v - for k, v in raw_data[idx].items() - if k not in ["response", "parsing_result"] - } - newdata["response"] = result["response"] - f.write(json.dumps(newdata) + "\n") 
- except Exception: - nfails += 1 + except Exception as e: continue - # if passed_indices: - # item = data[idx] - # item["parsing_result"] = [presults[i] for i in passed_indices] - # f.write(json.dumps(item) + "\n") - - print(f"Uncompilable: {uncompilable}") - print(f"Failed: {nfails}") + idx, response, code, passed, output = future_result + if "Failed to execute program" in output: + if not args.save_request_errors: + continue + newdata = form_new_data( + item=raw_data[idx], + response=response, + extracted_code=code, + pass_execution=passed, + output=output, + ) + f.write(json.dumps(newdata) + "\n") + n_passed += passed + n_passed_inner += passed + pbar_inner.set_description( + pass_rate_str(n_passed_inner, n_processed_inner) + ) + pbar.set_description(pass_rate_str(n_passed, n_processed)) + if cleanup_command is not None: + print(f"Cleaning up: {cleanup_command}") + os.system(cleanup_command) + print("Cleanup done.") + + n_total_passed = n_cached_passed + n_passed + n_total = len(all_tasks) + n_cached_total + print(pass_rate_str(n_total_passed, n_total, "Total")) if __name__ == "__main__": - print("Try to run this file using docker if possible!") - Fire(main) + main() diff --git a/src/star_align/llm_wrapper.py b/src/star_align/llm_wrapper.py index 669d3e0..426b55c 100644 --- a/src/star_align/llm_wrapper.py +++ b/src/star_align/llm_wrapper.py @@ -440,6 +440,13 @@ def get_model_context( other_kwargs["attn_implementation"] = "flash_attention_2" # other_kwargs["use_flash_attention_2"] = True # cls = AutoModelWithLMHead if "starcoder2-3b" in model_key else AutoModelForCausalLM + + if "starcoder" in model_key.lower(): + print("Hack for starcoder") + attention_dropout = attention_dropout or 0.0 + residual_dropout = residual_dropout or 0.0 + embedding_dropout = embedding_dropout or 0.0 + if attention_dropout is not None: other_kwargs["attention_dropout"] = attention_dropout if residual_dropout is not None: diff --git a/src/star_align/minhash_dedup.py b/src/star_align/minhash_dedup.py new file mode 100644 index 0000000..bb9b36b --- /dev/null +++ b/src/star_align/minhash_dedup.py @@ -0,0 +1,428 @@ +import json +import argparse +import random +import multiprocessing as mp +import os +import re +from collections import defaultdict +from typing import Any, Callable, List + +import click +import datasets +import numpy as np +from tqdm import tqdm + +import pickle # nosec +from collections import Counter +from pathlib import Path +from itertools import tee + +from scipy.integrate import quad as integrate + +import hashlib +import struct +from hashlib import md5 +from hashlib import sha256 + +import xxhash +from xxhash import xxh3_64 +from xxhash import xxh3_64_digest +from xxhash import xxh3_128 +from xxhash import xxh3_128_digest + + +parser = argparse.ArgumentParser() +# IO Args +parser.add_argument("--data_files", type=str, required=True) +parser.add_argument("--output", type=str, required=True) +parser.add_argument("--num_proc", type=int, default=os.cpu_count()) +# Meta Args +parser.add_argument("--column", type=str, required=True) +parser.add_argument("--batch_size", type=int, default=10_000) +# MinHash Args +parser.add_argument("--ngram", type=int, default=5) +parser.add_argument("--min_length", type=int, default=5) +parser.add_argument("--ignore_empty", type=bool, default=False) +parser.add_argument("--seed", type=int, default=42) +parser.add_argument("--num_perm", type=int, default=250) +parser.add_argument("--threshold", type=float, default=0.7) +parser.add_argument("--b", type=int, 
default=None) +parser.add_argument("--r", type=int, default=None) +parser.add_argument("--hash_func", type=str, default="sha1") +parser.add_argument("--hash_bits", type=int, default=64) +args = parser.parse_args() + + +def ngrams(sequence: List[str], n: int, min_length: int = 5): + """ + Return the ngrams generated from a sequence of items, as an iterator. + + This is a modified version of nltk.util.ngrams. + """ + if len(sequence) < min_length: + return [] + if len(sequence) < n: + return [tuple(sequence)] + iterables = tee(iter(sequence), n) + for i, sub_iterable in enumerate(iterables): + for _ in range(i): + next(sub_iterable, None) + return zip(*iterables) + + +class UnionFind: + """ + A data structure for maintaining disjoint sets. This helps build connected components for given duplicate pairs. + """ + + def __init__(self): + self.parent = {} + # Counter is a subclass of dict with slightly different python and c implementations + # you can think of it as an optimized defaultdict(int) + self.rank = Counter() + + def find(self, x): + try: + # path compression + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + except KeyError: + # KeyError happens if x not in parent + self.parent[x] = x + finally: + return self.parent[x] + + def union(self, x, y): + px = self.find(x) + py = self.find(y) + + # If both elements are already in the same set, do nothing + # The line in original UnionFind `self.parent[px] = self.parent[py] = min(px, py)` is redundant when px == py + if px == py: + return + + if self.rank[px] == self.rank[py]: + # If ranks are equal, choose one as the new root and increment its rank + # with few duplicates this is likely to be the most common case + self.parent[py] = px + self.rank[px] += 1 + # otherwise, assume that leftside is more likely to be higher rank + # Attach the smaller rank tree under the root of the larger rank tree + elif self.rank[px] > self.rank[py]: + self.parent[py] = px + else: + self.parent[px] = py + + def reset(self): + self.parent = {} + self.rank = Counter() + + def dump(self, path: str | Path, id2id=None): + if id2id is not None: + new_uf = UnionFind() + for i in self.parent: + new_uf.union(id2id[i], id2id[self.find(i)]) + else: + new_uf = self + + with open(path, "wb") as f: + pickle.dump(new_uf, f, protocol=pickle.HIGHEST_PROTOCOL) + + +RNG = np.random.RandomState(args.seed) +NON_ALPHA = re.compile(r"\W", re.UNICODE) +datasets.logging.set_verbosity_error() + +SIGNATURE_COLUMN = "__signatures__" +INDEX_COLUMN = "__index__" +CLUSTER_COLUMN = "__cluster__" + +# for is originally used to reduce memory usage in MacOS but also ensures that the Union Find data structure +# is not copied to child processes as long as it is not modified. +mp.set_start_method("fork", force=True) +uf = UnionFind() + + +def sha1_hash(data: bytes, d: int = 32) -> int: + """ + Generate a d-bit hash value from the given data. + """ + if d == 32: + return struct.unpack( + " int: + """ + Generate a 16-bit xxhash based hash value from the given data. + As of python xxhash 3.3.0 (and since 0.3.0) outputs in big-endian. + This is useful as a special purpose xxhash when you only want 16 bits. + bit masked xxh3_64 hashes are faster than xxh32 in modern systems. + """ + return xxhash.xxh3_64_intdigest(data, seed) & 0xFFFF + + +def xxh3_32hash(data: bytes, seed: int = 0) -> int: + """ + Generate a 32-bit xxhash based hash value from the given data. + As of python xxhash 3.3.0 (and since 0.3.0) outputs in big-endian. 
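The `UnionFind` above is what turns pairwise duplicate candidates into clusters: every pair produced by the LSH tables is unioned, and later only records that are their own cluster representative survive the final filter. A toy version of that logic (indices and pairs are invented) that roots each cluster at its smallest id, matching the keep-the-minimum filter used in `main`; the real class additionally unions by rank and can be pickled to disk:

```python
# Toy union-find clustering: union duplicate pairs, keep one record per cluster.
parent: dict[int, int] = {}

def find(x: int) -> int:
    parent.setdefault(x, x)
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # path halving
        x = parent[x]
    return x

def union(x: int, y: int) -> None:
    rx, ry = find(x), find(y)
    if rx != ry:
        parent[max(rx, ry)] = min(rx, ry)  # root clusters at the smallest id

for a, b in [(0, 3), (3, 7), (2, 5)]:  # hypothetical LSH candidate pairs
    union(a, b)

kept = [i for i in range(8) if find(i) == i]
print(kept)  # [0, 1, 2, 4, 6]
```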
+ This is useful as a special purpose xxhash when you only want 32bits. + bit masked xxh3_64 hashes are faster than xxh32 in modern systems. + """ + return xxhash.xxh3_64_intdigest(data, seed) & 0xFFFFFFFF + + +def optimal_param( + threshold: float, + num_perm: int, + false_positive_weight: float = 0.5, + false_negative_weight: float = 0.5, +): + """ + Compute the optimal `MinHashLSH` parameter that minimizes the weighted sum + of probabilities of false positive and false negative, taken from datasketch. + """ + + def false_positive_area(threshold: float, b: int, r: int): + """Source: `datasketch.lsh`""" + + def proba(s): + return 1 - (1 - s ** float(r)) ** float(b) + + a, _ = integrate(proba, 0.0, threshold) + return a + + def false_negative_area(threshold: float, b: int, r: int): + """Source: `datasketch.lsh`""" + + def proba(s): + return 1 - (1 - (1 - s ** float(r)) ** float(b)) + + a, _ = integrate(proba, threshold, 1.0) + return a + + min_error = float("inf") + opt = (0, 0) + for b in range(1, num_perm + 1): + max_r = int(num_perm / b) + for r in range(1, max_r + 1): + fp = false_positive_area(threshold, b, r) + fn = false_negative_area(threshold, b, r) + error = fp * false_positive_weight + fn * false_negative_weight + if error < min_error: + min_error = error + opt = (b, r) + return opt + + +def embed_func( + content: str, + idx: int, + *, + num_perm: int, + ngram_size: int, + min_length: int, + hashranges: list[tuple[int, int]], + permutations: np.ndarray, + hash_func: Callable, + dtype: type, + max_hash: np.uint, + modulo_prime: np.uint, +) -> dict[str, Any]: + """ + Calculate hash values for the content. + """ + # a, b are each np.ndarray arrays containing {num_perm} pairs of random numbers used for building new hashes + # the formula is a * x(base hash of each shingle) + b + a, b = permutations + # split content on whitespace (NON_ALPHA regex), tokenize with ngrams(), and join these n-grams into a single space separated string. + # we then convert to lower case and then bytestrings which is then hashed. Only unique hashed n-grams are left. + tokens: set[bytes] = { + bytes(" ".join(t).lower(), "utf-8") + for t in ngrams(NON_ALPHA.split(content.lower()), ngram_size, min_length) + } + + hashvalues: np.ndarray = np.array( + [hash_func(token) for token in tokens], dtype=dtype + ).reshape(len(tokens), 1) + # Permute the hash values to produce new universal hashes + # Element-wise multiplication with 'hashvalues' and a (non 0 random value) and then adding b + # Then, take modulo 'MODULO_PRIME' and bitwise_and with 'MAX_HASH' to keep only the necessary bits. + hashvalues = (hashvalues * a + b) % modulo_prime & max_hash + # this part is where the name "min" of minhash comes from + # this stacks all the hashes and then takes the minimum from each column + masks: np.ndarray = np.full(shape=num_perm, dtype=dtype, fill_value=max_hash) + hashvalues = np.vstack([hashvalues, masks]).min(axis=0) + # Originally, byteswap was done for speed. Testing show it has a negligible impact + # keeping for backward compatibility, even though theoretically and empirically + # it doesnt matter if it is there or not. 
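`optimal_param` searches over all band/row splits `(b, r)` of `num_perm` for the one that best balances false positives and false negatives around the similarity threshold. The quantity it integrates is the classic LSH S-curve: a pair with Jaccard similarity `s` lands in the same bucket of at least one band with probability `1 - (1 - s**r) ** b`. A quick numeric check with one hypothetical split of 250 permutations:

```python
# Probability that a pair with Jaccard similarity s becomes a duplicate candidate
# under MinHash LSH with b bands of r rows each.
def candidate_probability(s: float, b: int, r: int) -> float:
    return 1 - (1 - s**r) ** b

b, r = 25, 10  # one hypothetical split of num_perm = 250
for s in (0.5, 0.7, 0.9):
    print(f"s={s}: P(candidate) = {candidate_probability(s, b, r):.3f}")
# The curve stays low below the 0.7 threshold and rises steeply around it.
```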
github.com/ekzhu/datasketch/issues/114 + Hs: list[bytes] = [ + bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges + ] + return {SIGNATURE_COLUMN: Hs, INDEX_COLUMN: idx} + + +def main(): + global uf + uf.reset() + HASH_BITS: int = args.hash_bits + HASH_CONFIG: dict[int, tuple[type, Any, Any]] = { + 64: (np.uint64, np.uint32((1 << 32) - 1), np.uint64((1 << 61) - 1)), + # 32, 16 bit config does not use a mersenne prime. + # The original reason for using mersenne prime was speed. + # Testing reveals, there is no benefit to using a 2^61 mersenne prime for division + 32: (np.uint32, np.uint32((1 << 32) - 1), np.uint32((1 << 32) - 5)), + 16: (np.uint16, np.uint16((1 << 16) - 1), np.uint16((1 << 16) - 15)), + } + DTYPE, MAX_HASH, MODULO_PRIME = HASH_CONFIG.get(HASH_BITS, HASH_CONFIG[64]) + + match args.hash_func: + case "sha1": + + def hash_func(byte_data): + return sha1_hash(byte_data, d=min(HASH_BITS, 32)) + + case "xxh3": + if HASH_BITS == 16: + hash_func = xxh3_16hash + else: + hash_func = xxh3_32hash + + if args.b is not None and args.r is not None: + B, R = args.b, args.r + else: + # Compute the optimal `MinHashLSH` parameter that minimizes the weighted sum + # of probabilities of false positive and false negative, taken from datasketch. + B, R = optimal_param( + args.threshold, + args.num_perm, + false_positive_weight=0.5, + false_negative_weight=0.5, + ) + + HASH_RANGES = [(i * R, (i + 1) * R) for i in range(B)] + HASH_TABLES = [defaultdict(set) for _ in range(B)] + + PERMUTATIONS = ( + RNG.randint( + 1, MODULO_PRIME, size=(args.num_perm,), dtype=DTYPE + ), # a is a multiplier so should not be 0 + RNG.randint(0, MODULO_PRIME, size=(args.num_perm,), dtype=DTYPE), # b + ) + + # Loading + data_files_list = [x.strip() for x in args.data_files.split(",")] + ds = datasets.load_dataset("json", data_files=data_files_list, split="train") + ds = ds.map( + lambda x, i: {INDEX_COLUMN: i}, with_indices=True, num_proc=args.num_proc + ) + + if args.ignore_empty: + ds_rest = ds.filter(lambda x: len(x[args.column].strip()) == 0) + ds = ds.filter(lambda x: len(x[args.column].strip()) > 0) + + ds = ds.filter( + lambda x: len(NON_ALPHA.split(x[args.column].lower())) >= args.min_length, + num_proc=args.num_proc, + ) + + LEN_DATASET = len(ds) + if args.ignore_empty: + LEN_DATASET += len(ds_rest) + + # MinHashing + embedded = ds.map( + function=embed_func, + fn_kwargs={ + "num_perm": args.num_perm, + "hashranges": HASH_RANGES, + "ngram_size": args.ngram, + "min_length": args.min_length, + "permutations": PERMUTATIONS, + "hash_func": hash_func, + "dtype": DTYPE, + "max_hash": MAX_HASH, + "modulo_prime": MODULO_PRIME, + }, + input_columns=[args.column, INDEX_COLUMN], + remove_columns=[col for col in ds.column_names if col != INDEX_COLUMN], + num_proc=args.num_proc, + with_indices=False, + desc="Fingerprinting...", + ) + LEN_EMBEDDED = len(embedded) + NUM_SHARDS = np.ceil(LEN_EMBEDDED / args.batch_size).astype(int) + + # Clustering + edges = [] + for i in tqdm( + range(0, NUM_SHARDS), + dynamic_ncols=True, + desc="Iterating MinHashes...", # noqa: E501 + ): + embedded_shard = embedded.shard( + num_shards=NUM_SHARDS, + index=i, + contiguous=True, + writer_batch_size=args.batch_size, + ) + for key, Hs in zip( + embedded_shard[INDEX_COLUMN], embedded_shard[SIGNATURE_COLUMN] + ): + for i, H in enumerate(Hs): + HASH_TABLES[i][H].add(key) + + print(f"Number of clusters: {len(HASH_TABLES)}") + for table in tqdm(HASH_TABLES, dynamic_ncols=True, desc="Clustering..."): + # cluster: Set[int] + for 
cluster in table.values(): + if len(cluster) <= 1: + continue + idx = min(cluster) + for x in cluster: + edges.append((x, idx)) + uf.union(x, idx) + print(f"Number of edges: {len(set(edges))}") + + # Filtering + ds = ds.map( + function=lambda record: {CLUSTER_COLUMN: uf.find(record[INDEX_COLUMN])}, + with_indices=False, + num_proc=args.num_proc, + new_fingerprint=str(random.getrandbits(128)), + desc="Finding clusters...", + ) + # This is where the deduplication happens + # Since there is no easy groupby in datasets + # I will use this simple filter for now + final_data = ds.filter( + function=lambda record: record[CLUSTER_COLUMN] == record[INDEX_COLUMN], + with_indices=False, + num_proc=args.num_proc, + desc="Filtering clusters...", + ) + if args.ignore_empty and len(ds_rest) > 0: + final_data = datasets.concatenate_datasets([ds_rest, final_data]) + + # Saving + final_data = final_data.remove_columns([CLUSTER_COLUMN, INDEX_COLUMN]) + final_data.to_json(args.output) + print("Before:", LEN_DATASET) + print("After:", len(final_data)) + + # Cleaning + ds.cleanup_cache_files() + final_data.cleanup_cache_files() + + +if __name__ == "__main__": + main() diff --git a/src/star_align/sanitize_data.py b/src/star_align/sanitize_data.py index fa7081e..f0c9bee 100644 --- a/src/star_align/sanitize_data.py +++ b/src/star_align/sanitize_data.py @@ -1,33 +1,86 @@ -import os +"""Deduplication, filtering, and selection""" + import random -import sys - -from star_align.utils import ( - find_code_blocks, - read_jsonl, - remove_comments_from_code_blocks, - write_jsonl, -) - -src = sys.argv[1] -tgt = sys.argv[2] - -xs = read_jsonl(src) -print("Before:", len(xs)) - -random.seed(0) - -removing_tests = os.getenv("NO_TESTS") is not None -removing_explanation = os.getenv("NO_EXPL") is not None -removing_reasoning = os.getenv("NO_REASONING") is not None -removing_comments = os.getenv("NO_COMMENTS") is not None -removing_incomplete = os.getenv("NO_INCOMPLETE") is not None -codeonly = os.getenv("CODEONLY") is not None -augmentation_prob = float(os.getenv("AUGMENTATION", 0.0)) -keep_raw_format = os.getenv("RAW") is not None -smart = os.getenv("SMART") is not None - -incomplete_substrings = [ +import os +import ast +import re +import warnings +from dataclasses import dataclass, field +from pathlib import Path +from typing import cast, Literal +from datasets import load_dataset, Dataset +from tqdm.auto import tqdm +from transformers import HfArgumentParser + +from star_align.utils import find_code_blocks, write_jsonl, find_codeblock_indices + +LLAMA3 = os.getenv("LLAMA3") is not None +if LLAMA3: + print("LLAMA3 mode activated") + + +@dataclass(frozen=True) +class Args: + data_files: list[str] + output_file: str + shuffle: bool = field(default=True) + remove_strange: bool = field(default=True) + parse_raw_response: bool = field(default=True) + passing_only: bool = field(default=True) + data_augmentation: bool = field(default=False) + exact_match_dedup: bool = field(default=True) + get_code_representation: bool = field(default=True) + remove_comments_docstrings: bool = field(default=False) + include_left_failed: bool = field(default=False) + n_cores: int = field(default=os.cpu_count() or 1) + diversify_func_names: bool = field(default=True) + align_with: list[str] = field(default_factory=list) + priority: Literal["passed", "failed", "none"] = field(default="none") + seed: int = field(default=6666) + + +def extract_and_concat_function_names(python_content): + """ + Extracts all function names from a given Python content string 
and concatenates them into a single string. + + Parameters: + - python_content: A string containing the Python code to analyze. + + Returns: + - A string containing all function names defined in the content, concatenated. + """ + tree = ast.parse(python_content) + function_names = [] + class_names = [] + + # Define a node visitor that adds the name of each function definition it visits + class FuncClassDefVisitor(ast.NodeVisitor): + def visit_ClassDef(self, node: ast.ClassDef): + class_names.append(node.name) + self.generic_visit(node) + + def visit_FunctionDef(self, node): + function_names.append(node.name) + # Process the subtree for this node + self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node): + function_names.append(node.name) + self.generic_visit(node) + + # Create a node visitor and walk through the AST + visitor = FuncClassDefVisitor() + visitor.visit(tree) + + def compress_name(name: str) -> str: + return name.replace("_", "").lower() + + return frozenset(map(compress_name, function_names)), frozenset( + map(compress_name, class_names) + ) + + +INCOMPLETE_SUBSTRINGS = [ "todo", "fixme", "write your code here", @@ -36,157 +89,360 @@ "notimplemented", ] -if removing_tests: - print("Removing tests") -if removing_explanation: - print("Removing explanation") -if removing_reasoning: - print("Removing reasoning") -if removing_comments: - print("Removing comments") -if removing_incomplete: - print("Removing incomplete") -if codeonly: - print("Code only") -if augmentation_prob > 0: - print("Augmentation prob:", augmentation_prob) -if keep_raw_format: - print("Keeping raw format") -if smart: - removing_comments = True - augmentation_prob = 0.5 - print("Smart mode") - - -def filter_x(x): - response = x["response"] - tags = ["[Tests]", "[Reasoning]", "[Explanation]", "[Implementation]"] - # response = "".join(l for l in response.splitlines(keepends=True) if l.strip().startswith("")) - if any(response.count(tag) != 1 for tag in tags): - return False - tests_index = response.index("[Tests]") - explanation_index = response.index("[Explanation]") - reasoning_index = response.index("[Reasoning]") - implementation_index = response.index("[Implementation]") - if not (reasoning_index < implementation_index < explanation_index < tests_index): - return False - reasoning = response[ - reasoning_index + len("[Reasoning]") : implementation_index - ].strip() - explanation = response[ - explanation_index + len("[Explanation]") : tests_index - ].strip() - implementation = response[ - implementation_index + len("[Implementation]") : explanation_index - ].strip() - codeblocks = find_code_blocks(implementation, "python") - if len(codeblocks) == 0: - return False - if codeonly: - code = "\n\n".join(codeblocks) - implementation = f"```python\n{code}\n```" - tests = response[tests_index + len("[Tests]") :].strip() - # tests = tests.split("\n[")[0].split("\n##")[0].strip() - tests_blocks = find_code_blocks(tests, "python") - if len(tests_blocks) != 1 or tests.count("```") != 2: - if os.getenv("DEBUG"): - breakpoint() - return False - assert tests.count("```") == 2 - index = tests.rindex("```") + 3 - tests_prefix = tests[:index] - # tests_suffix = tests[index:] - # tests_suffix = tests_suffix.split("\n[")[0].split("\n#")[0].rstrip() - # tests = tests_prefix + tests_suffix - # remove NL after test block - tests = tests_prefix - strange_index = next( - (idx for idx, l in enumerate(tests.splitlines()) if l.startswith("/")), - None, - ) - if strange_index is not None: - # print("MD Index:", 
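`extract_and_concat_function_names` feeds the name-level deduplication used later in `main`: each response's defined function and class names are lower-cased, stripped of underscores, and collected into frozensets, so two responses that implement the same interface under cosmetically different names collapse to the same key. A self-contained sketch of that idea (the snippets are made up, and it merges functions and classes into one set for brevity):

```python
# Illustrative name-based dedup key: case- and underscore-insensitive definition names.
import ast

def compressed_def_names(code: str) -> frozenset[str]:
    tree = ast.parse(code)
    names = [
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
    ]
    return frozenset(name.replace("_", "").lower() for name in names)

a = compressed_def_names("def binary_search(xs, x):\n    ...\n")
b = compressed_def_names("def BinarySearch(items, target):\n    ...\n")
print(a, a == b)  # frozenset({'binarysearch'}) True
```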
strange_index) - tests = "".join(tests.splitlines(keepends=True)[:strange_index]).strip() +RESPONSE_TEST_SPLIT = "\n\n" +# special handling for llama3 since it has more examples not following the format +LLAMA3_DEFAULT_TEST_SPLIT = r"### Tests \d\n" +LLAMA3_ADDITIONAL_PATTERNS = [ + "We can verify the functionality", + "We can verify the correctness", + "You can verify the correctness", + "You can verify the functionality", + "To ensure the correctness", + "To verify the correctness", + "To test the", + "To test this", + "To test this", + "You can test the", + "We can test the", + "We can test this", + "Now, we'll test", +] + + +def split_llama3_response_tests(response: str) -> list[str]: + splits = re.split(LLAMA3_DEFAULT_TEST_SPLIT, response) + if len(splits) > 2: + return [] + if len(splits) == 2: + return splits + for pattern in LLAMA3_ADDITIONAL_PATTERNS: + index = response.find(pattern) + if index != -1: + return [response[:index], response[index:]] + return [] + + +def preprocess_and_filter(x: dict) -> dict: + """Filter out responses with wrong format""" + + def wrong_format(x: dict) -> dict: + return {k: v for k, v in x.items()} | dict(wrong_format=True, tests="") + + response: str = x["response"] + if not LLAMA3 and RESPONSE_TEST_SPLIT not in response: + return wrong_format(x) + if any(substring in response.lower() for substring in INCOMPLETE_SUBSTRINGS): + return wrong_format(x) + if LLAMA3: + splits = split_llama3_response_tests(response) + else: + splits = response.split(RESPONSE_TEST_SPLIT) + if len(splits) != 2: + return wrong_format(x) + response, tests = cast(tuple[str, str], tuple(map(str.strip, splits))) + response_codeblocks = find_code_blocks(response, "python") + tests_codeblocks = find_code_blocks(tests, "python") + if len(response_codeblocks) == 0 or len(tests_codeblocks) == 0: + return wrong_format(x) + + tests_content = "\n".join(tests_codeblocks) if "assert" not in tests or all( l.startswith("def") or l.startswith("class") or l.startswith("import") or l.startswith("from") - for l in tests_blocks[0].splitlines() + for l in tests_content.splitlines() if len(l) > 0 and l[0].isalpha() ): - return False - - global removing_tests, removing_explanation, removing_reasoning, removing_incomplete - - if smart: - removing_tests = "test" not in x["instruction"].lower() - removing_explanation = random.random() < 0.5 - removing_reasoning = random.random() < 0.5 - removing_incomplete = True - - if keep_raw_format: - contents = [ - "[Reasoning]\n" + reasoning, - "[Implementation]\n" + implementation, - "[Explanation]\n" + explanation, - "[Tests]\n" + tests, - ] - else: - contents = [reasoning, implementation, explanation, tests] + return wrong_format(x) - if removing_incomplete: - if any( - substring in x["response"].lower() for substring in incomplete_substrings - ): - return False - - if removing_tests: - contents.remove(tests) - if removing_explanation: - contents.remove(explanation) - if removing_reasoning: - contents.remove(reasoning) - x["response"] = "\n\n".join(contents) - - tests_block = tests_blocks[0] - lines = tests_block.splitlines() + newx = {k: v for k, v in x.items() if k != "response"} | dict( + response=response, tests=tests, wrong_format=False + ) + return newx + + +def augment_data(x: dict, index: int) -> dict: + random.seed(index) + tests_content = "\n".join(find_code_blocks(x["tests"])) + lines = tests_content.splitlines() if all(l.startswith("assert") for l in lines): ks = [1, 2, 3, 4, 5] assertions = random.sample(lines, k=min(random.choice(ks), 
len(lines))) assertion = "\n".join(assertions) assertion_term = "assertion" + ("s" if len(assertions) > 1 else "") else: - assertion = tests_block + assertion = tests_content assertion_term = "test case" if ( - augmentation_prob > 0 - and "assert" in assertion + "assert" in assertion # 5 lines augmented block max and len(assertion.splitlines()) <= 5 - and random.random() < augmentation_prob + and random.random() < 0.5 and "assert" not in x["instruction"] and "for example" not in x["instruction"].lower() - and (not smart or "test" not in x["instruction"].lower()) + and "test" not in x["instruction"].lower() ): - # if smart: - # contents.remove(tests) - # else: - assert removing_tests assert "assert" in assertion assertion_str = ( f"Your code should pass the following {assertion_term}:\n```python\n" + assertion.strip() + "\n```" ) - x["instruction"] = f"{x['instruction']}\n\n{assertion_str}" - if removing_comments: - x["response"] = remove_comments_from_code_blocks(x["response"]) - # for tag in tags: - # x["response"] = x["response"].replace(f"{tag}\n", "") - return True + new_instruction = f"{x['instruction']}\n\n{assertion_str}" + newx = {k: v for k, v in x.items()} | dict(instruction=new_instruction) + return newx + return x + + +# raw response -> response + test +# response/test -> passing (opt: passing) +# (not)passing -> unique +# unique -> aug / minihash / selection / educational -> final + + +def remove_ast_docstrings(tree): + # ref: https://gist.github.com/phpdude/1ae6f19de213d66286c8183e9e3b9ec1 + for node in ast.walk(tree): + # let's work only on functions & classes definitions + if not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)): + continue + if len(node.body) == 0: + continue + if not isinstance(node.body[0], ast.Expr): + continue + if ( + not hasattr(node.body[0], "value") + or not isinstance(node.body[0].value, ast.Str) + # or not isinstance(node.body[0].value.value, str) + ): + continue + node.body = node.body[1:] # type: ignore + return tree + + +def remove_comments_from_code_blocks( + content: str, +) -> str: + code_blocks = find_codeblock_indices(content) + # Current index in the original content for tracking purposes + current_index = 0 + # Buffer to store the new content + new_content: list[str] = [] + # Iterate over each code block + for start, end in code_blocks: + # Append the content before this code block + new_content.append(content[current_index:start]) + + # Extract the code block content + code_block_content = content[start:end] + + # Split into lines, process, and rejoin + modified_block_content = remove_comments(code_block_content) + + new_content.append(modified_block_content) + + # Update current index + current_index = end + + # Add the remaining part of the original content after the last code block + new_content.append(content[current_index:]) + + # Join all parts to form the final modified content + return "".join(new_content) + + +def remove_comments(code: str) -> str: + """Remove comments and docstrings using AST""" + tree = ast.parse(code) + tree = remove_ast_docstrings(tree) + return ast.unparse(tree) + + +def get_code_representation(response: str) -> str: + """Keep classes and functions, removing comments and docstrings""" + raw_code = "\n".join(find_code_blocks(response)) + + tree = ast.parse(raw_code) + + class ClassFunctionTransformer(ast.NodeTransformer): + def visit_Module(self, node): + # Visit all children nodes of the module + node = self.generic_visit(node) + # Filter out only function and class definitions + 
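`remove_ast_docstrings` and `remove_comments` rely on an AST round trip: comments never make it into the AST at all, and a leading string expression in a function or class body is a docstring that can simply be dropped before `ast.unparse` regenerates the source. A compact sketch of that trick (this variant also strips module-level docstrings, which the helper above does not touch):

```python
# Sketch: drop docstrings via the AST and regenerate the source (comments vanish too,
# because ast.unparse never emits them).
import ast

def strip_docstrings(code: str) -> str:
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            body = node.body
            if (
                body
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Constant)
                and isinstance(body[0].value.value, str)
            ):
                node.body = body[1:] or [ast.Pass()]  # keep the block syntactically valid
    return ast.unparse(tree)

src = 'def add(a, b):\n    """Add two numbers."""\n    # a comment\n    return a + b\n'
print(strip_docstrings(src))  # def add(a, b):\n    return a + b
```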
node.body = [ + n for n in node.body if isinstance(n, (ast.FunctionDef, ast.ClassDef)) + ] + return node + + visitor = ClassFunctionTransformer() + tree = visitor.visit(tree) + tree = remove_ast_docstrings(tree) + return ast.unparse(tree) + + +def map_code_representation(x: dict) -> dict: + try: + representation = get_code_representation(x["response"]) + except SyntaxError: + representation = "" + return {k: v for k, v in x.items()} | dict(code_representation=representation) + + +# def concat_list(lists: list[list]) -> list: +# return [item for sublist in lists for item in sublist] + + +def map_examples_batched(examples: dict, map_one) -> dict: + all_keys = list(examples.keys()) + list_of_examples = [ + {k: examples[k][i] for k in all_keys} for i in range(len(examples[all_keys[0]])) + ] + results = [map_one(example) for example in list_of_examples] + result_dict = {k: [result[k] for result in results] for k in results[0].keys()} + return result_dict + + +def map_remove_comments(x: dict) -> dict: + try: + response = x["response"] + except SyntaxError: + response = "" + return {k: v for k, v in x.items() if k != "response"} | dict(response=response) + + +def main(): + args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) + + raw_data = load_dataset("json", data_files=args.data_files, split="train") + if args.align_with: + ref_data = load_dataset("json", data_files=args.align_with, split="train") + ref_data_instructions = set(map(lambda x: x["instruction"], ref_data)) + raw_data = raw_data.filter( + lambda x: x["instruction"] in ref_data_instructions, num_proc=args.n_cores + ) + print("Raw samples:", len(raw_data)) + + if args.parse_raw_response: + raw_data = raw_data.map( + map_examples_batched, + fn_kwargs=dict(map_one=preprocess_and_filter), + batched=True, + num_proc=args.n_cores, + ) + raw_data = raw_data.filter( + lambda x: not x["wrong_format"], num_proc=args.n_cores + ) + raw_data = raw_data.remove_columns(["wrong_format"]) + print("Correct format:", len(raw_data)) + + if args.include_left_failed: + failed_data = raw_data.filter(lambda x: not x["pass"], num_proc=args.n_cores) + + if args.passing_only: + raw_data = raw_data.filter(lambda x: x["pass"], num_proc=args.n_cores) + print("Passing only:", len(raw_data)) + + if args.shuffle: + raw_data = raw_data.shuffle(seed=args.seed) + if args.include_left_failed: + failed_data = failed_data.shuffle(seed=args.seed) + + if args.priority != "none": + # Sort the examples such that failed/passed are at first + raw_data = raw_data.map( + map_examples_batched, + fn_kwargs=dict(map_one=lambda x: dict(**x, rank=int(x["pass"]))), + batched=True, + num_proc=args.n_cores, + ) + reverse = args.priority == "passed" + raw_data = raw_data.sort(column_names="rank", reverse=reverse) + raw_data = raw_data.remove_columns("rank") + + def mk_key(instruction: str) -> str: + return "".join(instruction.split()) + + seen_ids = set[frozenset[str]]() + seen_keys = set[str]() + if args.exact_match_dedup: + new_data = list[dict]() + + def iterate(dataset: Dataset): + for d in tqdm(dataset): + if args.remove_strange: + # NOTE: newly added + if len(d["instruction"].split()) > 200: + continue + key_i, key_r = mk_key(d["instruction"]), mk_key(d["response"]) + if key_i in seen_keys or key_r in seen_keys: + continue + if args.diversify_func_names: + code_block = find_code_blocks(d["response"])[0] + try: + fn_names, class_names = extract_and_concat_function_names( + code_block + ) + except SyntaxError: + continue + if (len(fn_names) > 0 and fn_names in 
seen_ids) or ( + len(class_names) > 0 and class_names in seen_ids + ): + continue + seen_ids.add(fn_names) + seen_ids.add(class_names) + new_data.append(d) + seen_keys.add(key_i) + seen_keys.add(key_r) + + iterate(raw_data) + if args.include_left_failed: + iterate(failed_data) + + print("Non exact matches:", len(new_data)) + else: + new_data = raw_data.to_list() + if args.include_left_failed: + new_data.extend(failed_data.to_list()) + new_dataset = Dataset.from_list(new_data) + + if args.get_code_representation: + new_dataset = new_dataset.map( + map_examples_batched, + fn_kwargs=dict(map_one=map_code_representation), + batched=True, + batch_size=1000, + # num_proc=args.n_cores, + ) + new_dataset = new_dataset.filter( + lambda x: x["code_representation"] != "", + num_proc=args.n_cores, + ) + print("Extracted code representation:", len(new_dataset)) + + if args.remove_comments_docstrings: + new_dataset = new_dataset.map( + map_examples_batched, + fn_kwargs=dict(map_one=map_remove_comments), + batched=True, + # num_proc=args.n_cores, + ) + new_dataset = new_dataset.filter( + lambda x: x["response"] != "", + num_proc=args.n_cores, + ) + print("Removed comments/docstrings:", len(new_dataset)) + + if args.data_augmentation: + new_dataset = new_dataset.map( + augment_data, + num_proc=args.n_cores, + with_indices=True, + ) + print("Augmented:", len(new_dataset)) + + write_jsonl(Path(args.output_file), new_dataset) -xs = [x for x in xs if filter_x(x)] -print("After:", len(xs)) -write_jsonl(tgt, xs) +if __name__ == "__main__": + main() diff --git a/src/star_align/self_ossinstruct.py b/src/star_align/self_ossinstruct.py index 1ea11a0..db7f171 100644 --- a/src/star_align/self_ossinstruct.py +++ b/src/star_align/self_ossinstruct.py @@ -7,14 +7,14 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Literal, cast - +from openai.types import CompletionChoice, Completion from datasets import Dataset, load_dataset from tqdm.auto import tqdm from transformers import HfArgumentParser import star_align -InstructMode = Literal["I->R", "S->C", "C->I"] +InstructMode = Literal["I->R", "S->C", "C->I", "S->I"] LANGUAGE_MAP = { "cpp": "C++", @@ -25,6 +25,29 @@ "typescript": "TypeScript", } +LLAMA3 = os.getenv("LLAMA3") is not None + +if LLAMA3: + print("Use Llama-3 prompt format") + + +def flatten_openai_responses(responses: list[Completion]) -> list[Completion]: + # assert all(len(response.choices) == chunk_size for response in responses) + completions = list[Completion]() + for idx, response in enumerate(responses): + completions.extend( + Completion( + id=f"{response.id}:{idx}", + created=response.created, + object=response.object, + model=response.model, + choices=[choice], + system_fingerprint=response.system_fingerprint, + ) + for choice in response.choices + ) + return completions + @dataclass(frozen=True) class Args: @@ -48,6 +71,15 @@ class Args: prompting_mode: Literal["chat", "completion"] = field(default="completion") num_fewshots: int = field(default=8) + async_micro_batch_size: int = field( + default=1, + metadata={ + "help": "Number of prompts to batch in a single async request." + "Won't be effective if `use_vllm_server` is False." + "It should be divisible by `num_batched_requests` for a balance load" + "if multiple vllm instances are served through a round-robin load balancer." 
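Both this script and `self_ossinstruct.py` follow the same configuration pattern: options are declared as a frozen dataclass and parsed with `transformers.HfArgumentParser`, so each field becomes a `--flag` whose type and default come from the annotation. A minimal standalone sketch of the pattern (the `DemoArgs` fields are invented for illustration):

```python
# Sketch of the dataclass-driven CLI used by these scripts (hypothetical fields).
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass(frozen=True)
class DemoArgs:
    data_files: list[str]
    output_file: str
    shuffle: bool = field(default=True)
    seed: int = field(default=6666)

if __name__ == "__main__":
    # e.g. python demo.py --data_files a.jsonl b.jsonl --output_file out.jsonl --seed 0
    (args,) = HfArgumentParser(DemoArgs).parse_args_into_dataclasses()
    print(args)
```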
+ }, + ) num_batched_requests: int = field( default=1, metadata={"help": "Number of requests to send concurrently"} ) @@ -95,9 +127,10 @@ class Property: def random_exercise(concepts: list[str], language: str) -> "Property": category = random.choice( [ - "code generation (function implementation)", - # "code generation (class implementation)", - # "code generation (program implementation)", + "function implementation", + "function implementation", + "class implementation", + "program implementation", ] ) difficulty = random.choice(["easy", "medium", "hard"]) @@ -114,9 +147,9 @@ def concepts_prompt(self) -> str: def prompt(self) -> str: category = f"category: {self.category}" language = f"language: {self.language}" - concepts = f"concepts: {self.concepts_prompt()}" difficulty = f"difficulty: {self.difficulty}" - return "\n".join([category, language, concepts, difficulty]) + concepts = f"concepts: {self.concepts_prompt()}" + return "\n".join([category, language, difficulty, concepts]) def to_json(self) -> dict[str, str | list[str]]: return dict( @@ -145,31 +178,58 @@ class Example: property: Property snippet: str instruction: str - solution: str | None response: str + tests: str @staticmethod def prefix_template(mode: InstructMode) -> str: if mode == "I->R": - return "### Instruction\n{instruction}\n\n### Response\n" + if LLAMA3: + template = ( + "### Instruction {index} \n{instruction}\n\n### Response {index}\n" + ) + else: + template = ( + "\n{instruction}\n\n\n\n" + ) + return template elif mode == "S->C": return "### Snippet\n{snippet}\n\n### Concepts\n" elif mode == "C->I": - return "### Property\n{property}\n\n### Instruction\n" + return "### Properties\n{property}\n\n### Task\n" + elif mode == "S->I": + return "### Snippet\n{snippet}\n\n### Task\n" else: assert False def prompt( - self, mode: InstructMode, return_in_separate: bool = False + self, + mode: InstructMode, + return_in_separate: bool = False, + index: int | None = None, ) -> str | tuple[str, str]: + assert index is None or (mode == "I->R" and LLAMA3) if mode == "I->R": kwargs = dict(instruction=self.instruction) - suffix = self.response + if LLAMA3: + assert index is not None + kwargs["index"] = str(index) + suffix = f"{self.response}\n\n### Tests {index}\n{self.tests}" + else: + suffix = ( + f"{self.response}\n\n\n\n{self.tests}\n" + ) elif mode == "S->C": kwargs = dict(snippet=self.snippet) suffix = self.property.concepts_prompt() elif mode == "C->I": - kwargs = dict(property=self.property.prompt()) + property_prompt = self.property.prompt() + # num_words = len(self.instruction.split()) + # property_prompt += f"\nnum_words: {num_words}" + kwargs = dict(property=property_prompt) + suffix = self.instruction + elif mode == "S->I": + kwargs = dict(snippet=self.snippet) suffix = self.instruction else: assert False @@ -185,6 +245,7 @@ class Fewshot: sys_i_r: str sys_c_i: str sys_s_c: str + sys_s_i: str examples: list[Example] @@ -193,10 +254,10 @@ def system_prompt(self, mode: InstructMode) -> str: return getattr(self, attr_name) def valid_examples(self, mode: InstructMode) -> list[Example]: - if mode in ["E->S", "I->RT", "I->R"]: - return [ - example for example in self.examples if example.solution is not None - ] + # if mode in ["E->S", "I->RT", "I->R"]: + # return [ + # example for example in self.examples if example.solution is not None + # ] return self.examples def random_prompt( @@ -210,19 +271,39 @@ def random_prompt( assert ( 0 < num_fewshots <= len(valid_examples) ), f"{num_fewshots=}, {len(valid_examples)=}" + # if 
mode == "I->R": + # # Hack + # category = format_args["category"] + # matching_examples = [ + # example + # for example in valid_examples + # if example.property.category == category + # ] + # assert len(matching_examples) > 0, f"{category=}" + # matching_example = random.choice(matching_examples) + # rest_of_examples = [ + # example for example in valid_examples if example is not matching_example + # ] + # assert len(rest_of_examples) == len(self.examples) - 1 + # examples = [matching_example] + random.sample( + # rest_of_examples, k=num_fewshots - 1 + # ) + # random.shuffle(examples) + # else: examples = random.sample(valid_examples, k=num_fewshots) + assert len(examples) == num_fewshots + body = "\n\n".join( - f"## Example {idx + 1}\n{example.prompt(mode)}" + f"## Example {idx + 1}\n{example.prompt(mode, index=idx + 1 if LLAMA3 and mode == 'I->R' else None)}" for idx, example in enumerate(examples) ) # content = f"{self.system_prompt}\n\n{body}" prefix_template = Example.prefix_template(mode) + if mode == "I->R" and LLAMA3: + format_args["index"] = str(len(examples) + 1) prefix = f"## Example {len(examples) + 1}\n" + prefix_template.format( **format_args ) - if mode == "E->S": - # NOTE: special handling for python exercise generation - prefix += "```python" system_prompt = self.system_prompt(mode) full_prompt = f"{system_prompt}\n\n{body}\n\n{prefix}" assert prompting_mode == "completion", "Only completion is supported for now" @@ -238,7 +319,7 @@ def parse_property(content: str) -> Property | None: lines = [line[line.index(":") + 1 :].strip() for line in lines] except ValueError: return None - category, language, concepts_str, difficulty = lines + category, language, difficulty, concepts_str = lines concepts = list(map(str.strip, concepts_str.split(","))) return Property(category, language, concepts, difficulty) @@ -249,8 +330,10 @@ def get_ossinstruct_fewshots() -> Fewshot: splits = re.split(r"### Example \d+", content) system_prompt = splits[0].strip() # "I->R", "E->S", "I->I", "PI->PI", "S->C" - sys_pattern = r"### System: I->R|### System: C->I|### System: S->C" - _, i_r, c_i, s_c = list(map(str.strip, re.split(sys_pattern, system_prompt))) + sys_pattern = r"### System: I->R|### System: C->I|### System: S->C|### System: S->I" + _, i_r, c_i, s_c, s_i = list(map(str.strip, re.split(sys_pattern, system_prompt))) + if LLAMA3: + i_r = f"{i_r}\n\nFor each '## Example' below, make sure you provide a '### Response' and a '### Tests' section." 
# system_prompt = re.split(r"### System: Instruction", system_prompt)[1] # instruction_system_prompt, response_system_prompt = system_prompt.split( # "### System: Response" @@ -258,22 +341,22 @@ def get_ossinstruct_fewshots() -> Fewshot: # instruction_system_prompt = instruction_system_prompt.strip() # response_system_prompt = response_system_prompt.strip() examples_str = [example.strip() for example in splits[1:]] - assert len(examples_str) == 16, len(examples_str) + assert len(examples_str) == 21, len(examples_str) examples = list[Example]() for example_str in examples_str: pattern = ( - r"\[Code\]\n|\[Property\]\n|\[Instruction\]\n|\[Solution\]\n|\[Response\]\n" + r"\[Code\]\n|\[Property\]\n|\[Instruction\]\n|\[Response\]\n|\[Tests\]\n" ) - _, snippet, property, instruction, solution, response = re.split( + _, snippet, property, instruction, response, tests = re.split( pattern, example_str ) snippet = snippet.rstrip() property = parse_property(property) assert property is not None instruction = instruction.strip() - solution = None if solution.strip() == "" else solution.strip() response = response.strip() - example = Example(property, snippet, instruction, solution, response) + tests = tests.strip() + example = Example(property, snippet, instruction, response, tests) examples.append(example) # if args.external_data is not None: # examples.extend(external_examples) @@ -281,6 +364,7 @@ def get_ossinstruct_fewshots() -> Fewshot: sys_i_r=i_r, sys_c_i=c_i, sys_s_c=s_c, + sys_s_i=s_i, examples=examples, ) @@ -293,6 +377,8 @@ def parse_generated_content(content: str, instruct_mode: InstructMode) -> dict | return dict(concepts=concepts) elif instruct_mode == "C->I": return dict(instruction=content.strip()) + elif instruct_mode == "S->I": + return dict(instruction=content.strip()) else: assert False @@ -301,13 +387,24 @@ def build_kwargs(instruct_mode: InstructMode, example: dict) -> dict[str, str]: kwargs = dict[str, str]() if instruct_mode == "I->R": kwargs["instruction"] = example["instruction"] - elif instruct_mode == "S->C": + # Hack + # category_index = example["prompt"].rindex("category: ") + len("category: ") + # category_end = example["prompt"].index("\n", category_index) + # category = example["prompt"][category_index:category_end].strip() + # kwargs["category"] = category # type: ignore + elif instruct_mode in ["S->C", "S->I"]: kwargs["snippet"] = example["seed"] elif instruct_mode == "C->I": lang = example.get("data_dir", "dummy_key_not_in_example") language = LANGUAGE_MAP.get(lang, "Python") property = Property.random_exercise(example["concepts"], language=language) - kwargs["property"] = property.prompt() + property_prompt = property.prompt() + # 45 / 152 are the min/max word lengths in the fewshot examples + # num_words = random.randint(1000, 1500) + # property_prompt += f"\nnum_words: {num_words}" + kwargs["property"] = property_prompt + # Hack + kwargs["property_obj"] = property # type: ignore else: assert False return kwargs @@ -331,6 +428,12 @@ def get_readable_prefix(instruct_mode: InstructMode, example: dict) -> str: async def main(): args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) + # Sanity check + assert args.num_batched_requests % args.async_micro_batch_size == 0 + if args.async_micro_batch_size > 1: + assert ( + args.num_sample_per_request == 1 + ), "Only support 1 sample with batched async requests" if args.use_vllm_server: openai_client = star_align.utils.OpenAIClient() @@ -341,9 +444,12 @@ async def main(): 
num_proc=star_align.utils.N_CORES, ) id_key = "seed" - assert len(set(d[id_key] for d in raw_dataset)) == len( - raw_dataset - ), "Duplicate seeds appear in the dataset" + if os.getenv("IGNORE_SEED_CHECK") is None: + assert len(set(d[id_key] for d in raw_dataset)) == len( + raw_dataset + ), "Duplicate seeds appear in the dataset" + else: + print("[Warning] Ignoring seed check") # Every run should produce the same data as long as the default params are not changed start_index = args.seed_code_start_index @@ -395,7 +501,6 @@ async def main(): if not args.use_vllm_server: from vllm import LLM, SamplingParams, RequestOutput - from openai.types import CompletionChoice, Completion import torch engine = LLM(args.model, tensor_parallel_size=torch.cuda.device_count()) @@ -457,6 +562,8 @@ def vllm_response_to_openai(response: RequestOutput) -> Completion: ) params["prompt"] = prompt params["stop"] = ["## Example"] + if args.instruct_mode == "I->R": + params["stop"].append("") request_params.append(params) assert len(request_params) == len(examples) print(f"Ready to make {len(request_params)} requests") @@ -466,14 +573,50 @@ def vllm_response_to_openai(response: RequestOutput) -> Completion: if args.prompting_mode == "chat" else openai_client.dispatch_completions ) - responses = await dispatch_requests(request_params, delay=args.delay) + if args.async_micro_batch_size == 1: + responses = await dispatch_requests(request_params, delay=args.delay) + else: + # Construct micro batches for async requests + assert args.num_sample_per_request == 1 + request_params_batched: list[dict[str, Any]] = [] + request_params_chunks = star_align.utils.chunked( + request_params, args.async_micro_batch_size + ) + for request_params_chunk in request_params_chunks: + request_param = { + k: v + for k, v in request_params_chunk[0].items() + if k != "prompt" + } + request_param["prompt"] = [ + req["prompt"] for req in request_params_chunk + ] + request_params_batched.append(request_param) + n_async_chunks = ( + args.num_batched_requests // args.async_micro_batch_size + ) + assert len(request_params_batched) in [ + n_async_chunks, + n_async_chunks + 1, + ], (request_params_batched, n_async_chunks) + print( + f"Ready to make {len(request_params_batched)} batched async requests" + ) + responses_batched = await dispatch_requests( + request_params_batched, delay=args.delay + ) + responses = flatten_openai_responses(responses_batched) + assert len(responses) == len(examples) else: + stop = ["## Example"] + if args.instruct_mode == "I->R": + stop.append("") sampling_params = SamplingParams( temperature=args.temperature, max_tokens=args.max_output_tokens, seed=args.seed + effective_index, n=args.num_sample_per_request, - stop=["## Example"], + stop=stop, ) vllm_responses = engine.generate(all_prompts, sampling_params) responses = list(map(vllm_response_to_openai, vllm_responses)) diff --git a/src/star_align/train.py b/src/star_align/train.py index 31405de..32e4ca8 100644 --- a/src/star_align/train.py +++ b/src/star_align/train.py @@ -1,4 +1,3 @@ -import os from dataclasses import dataclass, field from typing import cast diff --git a/src/star_align/utils.py b/src/star_align/utils.py index e10f4da..757139f 100644 --- a/src/star_align/utils.py +++ b/src/star_align/utils.py @@ -192,67 +192,52 @@ def find_codeblock_indices( return all_indices -def remove_comments_from_code_blocks( - content: str, -) -> str: - code_blocks = find_codeblock_indices(content) - # Current index in the original content for tracking purposes - current_index = 0 - 
# Buffer to store the new content - new_content: list[str] = [] - # Iterate over each code block - for start, end in code_blocks: - # Append the content before this code block - new_content.append(content[current_index:start]) - - # Extract the code block content - code_block_content = content[start:end] - - # Split into lines, process, and rejoin - lines = code_block_content.splitlines(keepends=True) - kept_lines = list[str]() - - i = 0 - while i < len(lines): - if ( - i != 0 - and i + 1 < len(lines) - and lines[i].strip() == "" - and lines[i + 1].lstrip().startswith("#") - ): - i += 2 - continue - if lines[i].lstrip().startswith("#"): - i += 1 - continue - kept_lines.append(lines[i]) - i += 1 - - # Join the processed lines and add to the modified blocks list - modified_block_content = "".join(kept_lines) - new_content.append(modified_block_content) - - # Update current index - current_index = end - - # Add the remaining part of the original content after the last code block - new_content.append(content[current_index:]) - - # Join all parts to form the final modified content - return "".join(new_content) +DEFAULT_TEMPLATE = """\ +### Instruction +{instruction} + +### Response +{response}""" + + +def is_base_model(tokenizer_name: str) -> bool: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + return tokenizer.chat_template is None and "octocoder" not in tokenizer_name + + +OCTOCODER_CHAT_TEMPLATE = """\ +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {{ raise_exception('System messages are not allowed in this template.') }} + {%- else %} + {%- if message['role'] == 'user' %} +{{'Question: ' + message['content'] + '\n\n'}} + {%- else %} +{{'Answer: ' + message['content'] + '\n\n'}} + {%- endif %} + {%- endif %} +{%- endfor %} +{{'Question: '}}""" def infer_prompt_template(tokenizer_name: str) -> str: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - template = tokenizer.apply_chat_template( - [ - {"role": "user", "content": "{instruction}"}, - {"role": "assistant", "content": "{response}"}, - ], - tokenize=False, - ) + if "octocoder" in tokenizer_name: + tokenizer.chat_template = OCTOCODER_CHAT_TEMPLATE + if tokenizer.chat_template is not None: + template = tokenizer.apply_chat_template( + [ + {"role": "user", "content": "{instruction}"}, + {"role": "assistant", "content": "{response}"}, + ], + tokenize=False, + ) + else: + template = DEFAULT_TEMPLATE end_index = template.rindex("{response}") + len("{response}") template = template[:end_index] return template
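
Below is a minimal usage sketch (not part of the patch) of how the `### Instruction` / `### Response` fallback introduced in `star_align/utils.py` is intended to behave when a tokenizer ships no chat template. The `render_example` helper and the example strings are illustrative assumptions only; a real run would call `infer_prompt_template(tokenizer_name)` with an actual checkpoint name and let it derive the template from the tokenizer when one exists.

```python
# Illustrative sketch only -- mirrors the fallback path of infer_prompt_template()
# without downloading a tokenizer. `render_example` is a hypothetical helper, not
# part of the patched module.

DEFAULT_TEMPLATE = """\
### Instruction
{instruction}

### Response
{response}"""


def render_example(
    instruction: str, response: str, template: str = DEFAULT_TEMPLATE
) -> str:
    # infer_prompt_template() keeps everything up to and including "{response}",
    # so the rendered string ends right after the assistant's answer.
    end_index = template.rindex("{response}") + len("{response}")
    return template[:end_index].format(instruction=instruction, response=response)


if __name__ == "__main__":
    print(
        render_example(
            "Write a function that reverses a string.",
            "def reverse(s: str) -> str:\n    return s[::-1]",
        )
    )
```

For a chat model (or for OctoCoder, which gets the hard-coded `OCTOCODER_CHAT_TEMPLATE`), the same truncation step strips whatever the chat template appends after the assistant turn, so training prompts always terminate immediately after the response placeholder.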