summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Bielik <martin.bielik@instea.sk>2024-12-21 00:44:35 +0100
committerMartin Bielik <martin.bielik@instea.sk>2024-12-21 00:45:31 +0100
commit2643c4f3e7a637d1c289a2ff3ad582deb11de3c0 (patch)
tree386b6e637cb0f9f756c3efc68862a7e71b63f54d
parent933a90d43ce9e360bb139dda2040b4360b9b12ce (diff)
downloadvim-ai-2643c4f3e7a637d1c289a2ff3ad582deb11de3c0.tar.gz
image to text support, closes #134
Diffstat (limited to '')
-rw-r--r--README.md20
-rw-r--r--py/chat.py3
-rw-r--r--py/utils.py116
-rw-r--r--tests/chat_test.py310
-rw-r--r--tests/resources/binary_file.bin2
-rw-r--r--tests/resources/image_file.jpg1
6 files changed, 345 insertions, 107 deletions
diff --git a/README.md b/README.md
index c3df247..456b904 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,7 @@ To get an idea what is possible to do with AI commands see the [prompts](https:/
- Edit selected text in-place with AI
- Interactive conversation with ChatGPT
- Custom roles
+- Vision capabilities (image to text)
- Integrates with any OpenAI-compatible API
## How it works
@@ -194,7 +195,7 @@ You are a Clean Code expert, I have the following code, please refactor it in a
```
-To include files in the chat a special `include` role is used:
+To include files in the chat a special `include` section is used:
```
>>> user
@@ -207,9 +208,22 @@ Generate documentation for the following files
/home/user/myproject/**/*.py
```
-Each file's contents will be added to an additional `user` role message with the files separated by `==> {path} <==`, where path is the path to the file. Globbing is expanded out via `glob.gob` and relative paths to the current working directory (as determined by `getcwd()`) will be resolved to absolute paths.
+Each file's contents will be added to an additional user message with `==> {path} <==` header, relative paths are resolved to the current working directory.
-Supported chat roles are **`>>> system`**, **`>>> user`**, **`>>> include`** and **`<<< assistant`**
+
+To use image vision capabilities (image to text) include an image file:
+
+```
+>>> user
+
+What object is on the image?
+
+>>> include
+
+~/myimage.jpg
+```
+
+Supported chat sections are **`>>> system`**, **`>>> user`**, **`>>> include`** and **`<<< assistant`**
### `:AIRedo`
diff --git a/py/chat.py b/py/chat.py
index 79457ee..6861b51 100644
--- a/py/chat.py
+++ b/py/chat.py
@@ -56,7 +56,8 @@ def run_ai_chat(context):
messages = initial_messages + chat_messages
try:
- if messages[-1]["content"].strip():
+ last_content = messages[-1]["content"][-1]
+ if last_content['type'] != 'text' or last_content['text']:
vim.command("normal! Go\n<<< assistant\n\n")
vim.command("redraw")
diff --git a/py/utils.py b/py/utils.py
index f5553ee..b2960f2 100644
--- a/py/utils.py
+++ b/py/utils.py
@@ -12,6 +12,7 @@ from urllib.error import URLError
from urllib.error import HTTPError
import traceback
import configparser
+import base64
utils_py_imported = True
@@ -109,63 +110,84 @@ def render_text_chunks(chunks):
if not full_text.strip():
raise KnownError('Empty response received. Tip: You can try modifying the prompt and retry.')
+def encode_image(image_path):
+ """Encodes an image file to a base64 string."""
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode('utf-8')
-def parse_chat_messages(chat_content):
- lines = chat_content.splitlines()
- messages = []
- for line in lines:
- if line.startswith(">>> system"):
- messages.append({"role": "system", "content": ""})
- continue
- if line.startswith(">>> user"):
- messages.append({"role": "user", "content": ""})
- continue
- if line.startswith(">>> include"):
- messages.append({"role": "include", "content": ""})
- continue
- if line.startswith("<<< assistant"):
- messages.append({"role": "assistant", "content": ""})
- continue
- if not messages:
- continue
- messages[-1]["content"] += "\n" + line
- for message in messages:
- # strip newlines from the content as it causes empty responses
- message["content"] = message["content"].strip()
+def is_image_path(path):
+ ext = path.strip().split('.')[-1]
+ return ext in ['jpg', 'jpeg', 'png', 'gif']
- if message["role"] == "include":
- message["role"] = "user"
- paths = message["content"].split("\n")
- message["content"] = ""
+def parse_include_paths(path):
+ if not path:
+ return []
+ pwd = vim.eval('getcwd()')
- pwd = vim.eval("getcwd()")
- for i in range(len(paths)):
- path = os.path.expanduser(paths[i])
- if not os.path.isabs(path):
- path = os.path.join(pwd, path)
+ path = os.path.expanduser(path)
+ if not os.path.isabs(path):
+ path = os.path.join(pwd, path)
- paths[i] = path
+ expanded_paths = [path]
+ if '*' in path:
+ expanded_paths = glob.glob(path, recursive=True)
- if '**' in path:
- paths[i] = None
- paths.extend(glob.glob(path, recursive=True))
+ return [path for path in expanded_paths if not os.path.isdir(path)]
- for path in paths:
- if path is None:
- continue
+def make_image_message(path):
+ ext = path.split('.')[-1]
+ base64_image = encode_image(path)
+ return { 'type': 'image_url', 'image_url': { 'url': f"data:image/{ext.replace('.', '')};base64,{base64_image}" } }
+
+def make_text_file_message(path):
+ try:
+ with open(path, 'r') as file:
+ file_content = file.read().strip()
+ return { 'type': 'text', 'text': f'==> {path} <==\n' + file_content.strip() }
+ except UnicodeDecodeError:
+ return { 'type': 'text', 'text': f'==> {path} <==\nBinary file, cannot display' }
+
+def parse_chat_messages(chat_content):
+ lines = chat_content.splitlines()
+ messages = []
- if os.path.isdir(path):
+ current_type = ''
+ for line in lines:
+ match line:
+ case '>>> system':
+ messages.append({'role': 'system', 'content': [{ 'type': 'text', 'text': '' }]})
+ current_type = 'system'
+ case '<<< assistant':
+ messages.append({'role': 'assistant', 'content': [{ 'type': 'text', 'text': '' }]})
+ current_type = 'assistant'
+ case '>>> user':
+ if messages and messages[-1]['role'] == 'user':
+ messages[-1]['content'].append({ 'type': 'text', 'text': '' })
+ else:
+ messages.append({'role': 'user', 'content': [{ 'type': 'text', 'text': '' }]})
+ current_type = 'user'
+ case '>>> include':
+ if not messages or messages[-1]['role'] != 'user':
+ messages.append({'role': 'user', 'content': []})
+ current_type = 'include'
+ case _:
+ if not messages:
continue
+ match current_type:
+ case 'assistant' | 'system' | 'user':
+ messages[-1]['content'][-1]['text'] += '\n' + line
+ case 'include':
+ paths = parse_include_paths(line)
+ for path in paths:
+ content = make_image_message(path) if is_image_path(path) else make_text_file_message(path)
+ messages[-1]['content'].append(content)
- try:
- with open(path, "r") as file:
- file_content = file.read().strip()
- message["content"] += f"\n\n==> {path} <==\n" + file_content
- except UnicodeDecodeError:
- message["content"] += "\n\n" + f"==> {path} <=="
- message["content"] += "\n" + "Binary file, cannot display"
- message['content'] = message['content'].strip()
+ for message in messages:
+ # strip newlines from the text content as it causes empty responses
+ for content in message['content']:
+ if content['type'] == 'text':
+ content['text'] = content['text'].strip()
return messages
diff --git a/tests/chat_test.py b/tests/chat_test.py
index 9acfecb..e29368c 100644
--- a/tests/chat_test.py
+++ b/tests/chat_test.py
@@ -15,10 +15,18 @@ def test_parse_user_message():
generate lorem ipsum
""")
- messages = parse_chat_messages(chat_content)
- assert 1 == len(messages)
- assert 'user' == messages[0]['role']
- assert 'generate lorem ipsum' == messages[0]['content']
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'generate lorem ipsum',
+ },
+ ],
+ },
+ ] == actual_messages
def test_parse_system_message():
@@ -31,12 +39,56 @@ def test_parse_system_message():
generate lorem ipsum
""")
- messages = parse_chat_messages(chat_content)
- assert 2 == len(messages)
- assert 'system' == messages[0]['role']
- assert 'you are general assystant' == messages[0]['content']
- assert 'user' == messages[1]['role']
- assert 'generate lorem ipsum' == messages[1]['content']
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'system',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'you are general assystant',
+ },
+ ],
+ },
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'generate lorem ipsum',
+ },
+ ],
+ },
+ ] == actual_messages
+
+
+def test_parse_two_user_messages():
+ chat_content = strip_text(
+ """
+ >>> user
+
+ generate lorem ipsum
+
+ >>> user
+
+ in english
+ """)
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'generate lorem ipsum',
+ },
+ {
+ 'type': 'text',
+ 'text': 'in english',
+ },
+ ],
+ },
+ ] == actual_messages
def test_parse_assistant_message():
chat_content = strip_text("""
@@ -52,14 +104,36 @@ def test_parse_assistant_message():
again
""")
- messages = parse_chat_messages(chat_content)
- assert 3 == len(messages)
- assert 'user' == messages[0]['role']
- assert 'generate lorem ipsum' == messages[0]['content']
- assert 'assistant' == messages[1]['role']
- assert 'bla bla bla' == messages[1]['content']
- assert 'user' == messages[2]['role']
- assert 'again' == messages[2]['content']
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'generate lorem ipsum',
+ },
+ ],
+ },
+ {
+ 'role': 'assistant',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'bla bla bla',
+ },
+ ],
+ },
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'again',
+ },
+ ],
+ },
+ ] == actual_messages
def test_parse_include_single_file_message():
chat_content = strip_text(f"""
@@ -70,17 +144,50 @@ def test_parse_include_single_file_message():
>>> include
{curr_dir}/resources/test1.include.txt
+
+ <<< assistant
+
+ it already is in human language
+
+ >>> user
+
+ try harder
""")
messages = parse_chat_messages(chat_content)
- assert 2 == len(messages)
- assert 'user' == messages[0]['role']
- assert 'translate to human language' == messages[0]['content']
- assert 'user' == messages[1]['role']
- expected_content = strip_text(f"""
- ==> {curr_dir}/resources/test1.include.txt <==
- hello world
- """)
- assert expected_content == messages[1]['content']
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'translate to human language',
+ },
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world',
+ },
+ ],
+ },
+ {
+ 'role': 'assistant',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'it already is in human language',
+ },
+ ],
+ },
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'try harder',
+ },
+ ],
+ },
+ ] == actual_messages
def test_parse_include_multiple_files_message():
chat_content = strip_text(f"""
@@ -94,18 +201,26 @@ def test_parse_include_multiple_files_message():
{curr_dir}/resources/test2.include.txt
""")
messages = parse_chat_messages(chat_content)
- assert 2 == len(messages)
- assert 'user' == messages[0]['role']
- assert 'translate to human language' == messages[0]['content']
- assert 'user' == messages[1]['role']
- expected_content = strip_text(f"""
- ==> {curr_dir}/resources/test1.include.txt <==
- hello world
-
- ==> {curr_dir}/resources/test2.include.txt <==
- vim is awesome
- """)
- assert expected_content == messages[1]['content']
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'translate to human language',
+ },
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world',
+ },
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/test2.include.txt <==\nvim is awesome',
+ },
+ ],
+ },
+ ] == actual_messages
def test_parse_include_glob_files_message():
chat_content = strip_text(f"""
@@ -117,24 +232,107 @@ def test_parse_include_glob_files_message():
{curr_dir}/**/*.include.txt
""")
- messages = parse_chat_messages(chat_content)
- assert 2 == len(messages)
- assert 'user' == messages[0]['role']
- assert 'translate to human language' == messages[0]['content']
- assert 'user' == messages[1]['role']
- expected_content = strip_text(f"""
- ==> {curr_dir}/resources/test1.include.txt <==
- hello world
-
- ==> {curr_dir}/resources/test2.include.txt <==
- vim is awesome
- """)
- assert expected_content == messages[1]['content']
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'translate to human language',
+ },
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world',
+ },
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/test2.include.txt <==\nvim is awesome',
+ },
+ ],
+ },
+ ] == actual_messages
def test_parse_include_image_message():
- # TODO
- pass
+ chat_content = strip_text(f"""
+ >>> user
+
+ what is on the image?
+
+ >>> include
+
+ {curr_dir}/**/*.jpg
+ """)
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': 'what is on the image?',
+ },
+ {
+ 'type': 'image_url',
+ 'image_url': {
+ 'url': 'data:image/jpg;base64,aW1hZ2UgZGF0YQo='
+ },
+ },
+ ],
+ },
+ ] == actual_messages
def test_parse_include_image_with_files_message():
- # TODO
- pass
+ chat_content = strip_text(f"""
+ >>> include
+
+ {curr_dir}/resources/test1.include.txt
+ {curr_dir}/resources/image_file.jpg
+ {curr_dir}/resources/test2.include.txt
+ """)
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world',
+ },
+ {
+ 'type': 'image_url',
+ 'image_url': {
+ 'url': 'data:image/jpg;base64,aW1hZ2UgZGF0YQo='
+ },
+ },
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/test2.include.txt <==\nvim is awesome',
+ },
+ ],
+ },
+ ] == actual_messages
+
+def test_parse_include_unsupported_binary_file():
+ chat_content = strip_text(f"""
+ >>> include
+
+ {curr_dir}/resources/binary_file.bin
+ {curr_dir}/resources/test1.include.txt
+ """)
+ actual_messages = parse_chat_messages(chat_content)
+ assert [
+ {
+ 'role': 'user',
+ 'content': [
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/binary_file.bin <==\nBinary file, cannot display',
+ },
+ {
+ 'type': 'text',
+ 'text': f'==> {curr_dir}/resources/test1.include.txt <==\nhello world',
+ },
+ ],
+ },
+ ] == actual_messages
diff --git a/tests/resources/binary_file.bin b/tests/resources/binary_file.bin
new file mode 100644
index 0000000..acd83ea
--- /dev/null
+++ b/tests/resources/binary_file.bin
@@ -0,0 +1,2 @@
+#n@E[+)W~Eo{ZL>]^ʒZxv"E%,boqqȸs v$)by7¶|SC . !UL3Sbu$cjӧ }D>|E
+rc(*(h WGZ?쮘po$B\vz풮ƌ:'=$o6v~p$o5haMSy(9ֺf`Oa|>kL \ No newline at end of file
diff --git a/tests/resources/image_file.jpg b/tests/resources/image_file.jpg
new file mode 100644
index 0000000..3433468
--- /dev/null
+++ b/tests/resources/image_file.jpg
@@ -0,0 +1 @@
+image data