diff options
Diffstat (limited to '')
| -rw-r--r-- | README.md | 20 | ||||
| -rw-r--r-- | py/chat.py | 3 | ||||
| -rw-r--r-- | py/utils.py | 116 | ||||
| -rw-r--r-- | tests/chat_test.py | 310 | ||||
| -rw-r--r-- | tests/resources/binary_file.bin | 2 | ||||
| -rw-r--r-- | tests/resources/image_file.jpg | 1 |
6 files changed, 345 insertions, 107 deletions
@@ -13,6 +13,7 @@ To get an idea what is possible to do with AI commands see the [prompts](https:/ - Edit selected text in-place with AI - Interactive conversation with ChatGPT - Custom roles +- Vision capabilities (image to text) - Integrates with any OpenAI-compatible API ## How it works @@ -194,7 +195,7 @@ You are a Clean Code expert, I have the following code, please refactor it in a ``` -To include files in the chat a special `include` role is used: +To include files in the chat a special `include` section is used: ``` >>> user @@ -207,9 +208,22 @@ Generate documentation for the following files /home/user/myproject/**/*.py ``` -Each file's contents will be added to an additional `user` role message with the files separated by `==> {path} <==`, where path is the path to the file. Globbing is expanded out via `glob.gob` and relative paths to the current working directory (as determined by `getcwd()`) will be resolved to absolute paths. +Each file's contents will be added to an additional user message with `==> {path} <==` header, relative paths are resolved to the current working directory. -Supported chat roles are **`>>> system`**, **`>>> user`**, **`>>> include`** and **`<<< assistant`** + +To use image vision capabilities (image to text) include an image file: + +``` +>>> user + +What object is on the image? + +>>> include + +~/myimage.jpg +``` + +Supported chat sections are **`>>> system`**, **`>>> user`**, **`>>> include`** and **`<<< assistant`** ### `:AIRedo` @@ -56,7 +56,8 @@ def run_ai_chat(context): messages = initial_messages + chat_messages try: - if messages[-1]["content"].strip(): + last_content = messages[-1]["content"][-1] + if last_content['type'] != 'text' or last_content['text']: vim.command("normal! Go\n<<< assistant\n\n") vim.command("redraw") diff --git a/py/utils.py b/py/utils.py index f5553ee..b2960f2 100644 --- a/py/utils.py +++ b/py/utils.py @@ -12,6 +12,7 @@ from urllib.error import URLError from urllib.error import HTTPError import traceback import configparser +import base64 utils_py_imported = True @@ -109,63 +110,84 @@ def render_text_chunks(chunks): if not full_text.strip(): raise KnownError('Empty response received. Tip: You can try modifying the prompt and retry.') +def encode_image(image_path): + """Encodes an image file to a base64 string.""" + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') -def parse_chat_messages(chat_content): - lines = chat_content.splitlines() - messages = [] - for line in lines: - if line.startswith(">>> system"): - messages.append({"role": "system", "content": ""}) - continue - if line.startswith(">>> user"): - messages.append({"role": "user", "content": ""}) - continue - if line.startswith(">>> include"): - messages.append({"role": "include", "content": ""}) - continue - if line.startswith("<<< assistant"): - messages.append({"role": "assistant", "content": ""}) - continue - if not messages: - continue - messages[-1]["content"] += "\n" + line - for message in messages: - # strip newlines from the content as it causes empty responses - message["content"] = message["content"].strip() +def is_image_path(path): + ext = path.strip().split('.')[-1] + return ext in ['jpg', 'jpeg', 'png', 'gif'] - if message["role"] == "include": - message["role"] = "user" - paths = message["content"].split("\n") - message["content"] = "" +def parse_include_paths(path): + if not path: + return [] + pwd = vim.eval('getcwd()') - pwd = vim.eval("getcwd()") - for i in range(len(paths)): - path = os.path.expanduser(paths[i]) - if not os.path.isabs(path): - path = os.path.join(pwd, path) + path = os.path.expanduser(path) + if not os.path.isabs(path): + path = os.path.join(pwd, path) - paths[i] = path + expanded_paths = [path] + if '*' in path: + expanded_paths = glob.glob(path, recursive=True) - if '**' in path: - paths[i] = None - paths.extend(glob.glob(path, recursive=True)) + return [path for path in expanded_paths if not os.path.isdir(path)] - for path in paths: - if path is None: - continue +def make_image_message(path): + ext = path.split('.')[-1] + base64_image = encode_image(path) + return { 'type': 'image_url', 'image_url': { 'url': f"data:image/{ext.replace('.', '')};base64,{base64_image}" } } + +def make_text_file_message(path): + try: + with open(path, 'r') as file: + file_content = file.read().strip() + return { 'type': 'text', 'text': f'==> {path} <==\n' + file_content.strip() } + except UnicodeDecodeError: + return { 'type': 'text', 'text': f'==> {path} <==\nBinary file, cannot display' } + +def parse_chat_messages(chat_content): + lines = chat_content.splitlines() + messages = [] - if os.path.isdir(path): + current_type = '' + for line in lines: + match line: + case '>>> system': + messages.append({'role': 'system', 'content': [{ 'type': 'text', 'text': '' }]}) + current_type = 'system' + case '<<< assistant': + messages.append({'role': 'assistant', 'content': [{ 'type': 'text', 'text': '' }]}) + current_type = 'assistant' + case '>>> user': + if messages and messages[-1]['role'] == 'user': + messages[-1]['content'].append({ 'type': 'text', 'text': '' }) + else: + messages.append({'role': 'user', 'content': [{ 'type': 'text', 'text': '' }]}) + current_type = 'user' + case '>>> include': + if not messages or messages[-1]['role'] != 'user': + messages.append({'role': 'user', 'content': []}) + current_type = 'include' + case _: + if not messages: continue + match current_type: + case 'assistant' | 'system' | 'user': + messages[-1]['content'][-1]['text'] += '\n' + line + case 'include': + paths = parse_include_paths(line) + for path in paths: + content = make_image_message(path) if is_image_path(path) else make_text_file_message(path) + messages[-1]['content'].append(content) - try: - with open(path, "r") as file: - file_content = file.read().strip() - message["content"] += f"\n\n==> {path} <==\n" + file_content - except UnicodeDecodeError: - message["content"] += "\n\n" + f"==> {path} <==" - message["content"] += "\n" + "Binary file, cannot display" - message['content'] = message['content'].strip() + for message in messages: + # strip newlines from the text content as it causes empty responses + for content in message['content']: + if content['type'] == 'text': + content['text'] = content['text'].strip() return messages diff --git a/tests/chat_test.py b/tests/chat_test.py index 9acfecb..e29368c 100644 --- a/tests/chat_test.py +++ b/tests/chat_test.py @@ -15,10 +15,18 @@ def test_parse_user_message(): generate lorem ipsum """) - messages = parse_chat_messages(chat_content) - assert 1 == len(messages) - assert 'user' == messages[0]['role'] - assert 'generate lorem ipsum' == messages[0]['content'] + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'generate lorem ipsum', + }, + ], + }, + ] == actual_messages def test_parse_system_message(): @@ -31,12 +39,56 @@ def test_parse_system_message(): generate lorem ipsum """) - messages = parse_chat_messages(chat_content) - assert 2 == len(messages) - assert 'system' == messages[0]['role'] - assert 'you are general assystant' == messages[0]['content'] - assert 'user' == messages[1]['role'] - assert 'generate lorem ipsum' == messages[1]['content'] + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'system', + 'content': [ + { + 'type': 'text', + 'text': 'you are general assystant', + }, + ], + }, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'generate lorem ipsum', + }, + ], + }, + ] == actual_messages + + +def test_parse_two_user_messages(): + chat_content = strip_text( + """ + >>> user + + generate lorem ipsum + + >>> user + + in english + """) + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'generate lorem ipsum', + }, + { + 'type': 'text', + 'text': 'in english', + }, + ], + }, + ] == actual_messages def test_parse_assistant_message(): chat_content = strip_text(""" @@ -52,14 +104,36 @@ def test_parse_assistant_message(): again """) - messages = parse_chat_messages(chat_content) - assert 3 == len(messages) - assert 'user' == messages[0]['role'] - assert 'generate lorem ipsum' == messages[0]['content'] - assert 'assistant' == messages[1]['role'] - assert 'bla bla bla' == messages[1]['content'] - assert 'user' == messages[2]['role'] - assert 'again' == messages[2]['content'] + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'generate lorem ipsum', + }, + ], + }, + { + 'role': 'assistant', + 'content': [ + { + 'type': 'text', + 'text': 'bla bla bla', + }, + ], + }, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'again', + }, + ], + }, + ] == actual_messages def test_parse_include_single_file_message(): chat_content = strip_text(f""" @@ -70,17 +144,50 @@ def test_parse_include_single_file_message(): >>> include {curr_dir}/resources/test1.include.txt + + <<< assistant + + it already is in human language + + >>> user + + try harder """) messages = parse_chat_messages(chat_content) - assert 2 == len(messages) - assert 'user' == messages[0]['role'] - assert 'translate to human language' == messages[0]['content'] - assert 'user' == messages[1]['role'] - expected_content = strip_text(f""" - ==> {curr_dir}/resources/test1.include.txt <== - hello world - """) - assert expected_content == messages[1]['content'] + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'translate to human language', + }, + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world', + }, + ], + }, + { + 'role': 'assistant', + 'content': [ + { + 'type': 'text', + 'text': 'it already is in human language', + }, + ], + }, + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'try harder', + }, + ], + }, + ] == actual_messages def test_parse_include_multiple_files_message(): chat_content = strip_text(f""" @@ -94,18 +201,26 @@ def test_parse_include_multiple_files_message(): {curr_dir}/resources/test2.include.txt """) messages = parse_chat_messages(chat_content) - assert 2 == len(messages) - assert 'user' == messages[0]['role'] - assert 'translate to human language' == messages[0]['content'] - assert 'user' == messages[1]['role'] - expected_content = strip_text(f""" - ==> {curr_dir}/resources/test1.include.txt <== - hello world - - ==> {curr_dir}/resources/test2.include.txt <== - vim is awesome - """) - assert expected_content == messages[1]['content'] + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'translate to human language', + }, + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world', + }, + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/test2.include.txt <==\nvim is awesome', + }, + ], + }, + ] == actual_messages def test_parse_include_glob_files_message(): chat_content = strip_text(f""" @@ -117,24 +232,107 @@ def test_parse_include_glob_files_message(): {curr_dir}/**/*.include.txt """) - messages = parse_chat_messages(chat_content) - assert 2 == len(messages) - assert 'user' == messages[0]['role'] - assert 'translate to human language' == messages[0]['content'] - assert 'user' == messages[1]['role'] - expected_content = strip_text(f""" - ==> {curr_dir}/resources/test1.include.txt <== - hello world - - ==> {curr_dir}/resources/test2.include.txt <== - vim is awesome - """) - assert expected_content == messages[1]['content'] + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'translate to human language', + }, + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world', + }, + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/test2.include.txt <==\nvim is awesome', + }, + ], + }, + ] == actual_messages def test_parse_include_image_message(): - # TODO - pass + chat_content = strip_text(f""" + >>> user + + what is on the image? + + >>> include + + {curr_dir}/**/*.jpg + """) + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': 'what is on the image?', + }, + { + 'type': 'image_url', + 'image_url': { + 'url': 'data:image/jpg;base64,aW1hZ2UgZGF0YQo=' + }, + }, + ], + }, + ] == actual_messages def test_parse_include_image_with_files_message(): - # TODO - pass + chat_content = strip_text(f""" + >>> include + + {curr_dir}/resources/test1.include.txt + {curr_dir}/resources/image_file.jpg + {curr_dir}/resources/test2.include.txt + """) + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world', + }, + { + 'type': 'image_url', + 'image_url': { + 'url': 'data:image/jpg;base64,aW1hZ2UgZGF0YQo=' + }, + }, + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/test2.include.txt <==\nvim is awesome', + }, + ], + }, + ] == actual_messages + +def test_parse_include_unsupported_binary_file(): + chat_content = strip_text(f""" + >>> include + + {curr_dir}/resources/binary_file.bin + {curr_dir}/resources/test1.include.txt + """) + actual_messages = parse_chat_messages(chat_content) + assert [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/binary_file.bin <==\nBinary file, cannot display', + }, + { + 'type': 'text', + 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world', + }, + ], + }, + ] == actual_messages diff --git a/tests/resources/binary_file.bin b/tests/resources/binary_file.bin new file mode 100644 index 0000000..acd83ea --- /dev/null +++ b/tests/resources/binary_file.bin @@ -0,0 +1,2 @@ +#n@E[+)W~Eo{ZL>]^ʒZxv"E%,boqqȸsv$)by7¶|SC. !UL3Sbu$cjӧ }D>|E +rc(*(hWGZ?쮘po$B\vz풮ƌ:'=$o6v~p$o5haMSy(9ֺf`Oa|>kL
\ No newline at end of file diff --git a/tests/resources/image_file.jpg b/tests/resources/image_file.jpg new file mode 100644 index 0000000..3433468 --- /dev/null +++ b/tests/resources/image_file.jpg @@ -0,0 +1 @@ +image data |