summaryrefslogtreecommitdiff
path: root/py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--py/context.py5
-rw-r--r--py/image.py50
-rw-r--r--py/roles.py9
-rw-r--r--py/utils.py7
4 files changed, 66 insertions, 5 deletions
diff --git a/py/context.py b/py/context.py
index 581f8ad..343a050 100644
--- a/py/context.py
+++ b/py/context.py
@@ -79,7 +79,7 @@ def load_role_config(role):
enhance_roles_with_custom_function(roles)
- postfixes = ["", ".complete", ".edit", ".chat"]
+ postfixes = ["", ".complete", ".edit", ".chat", ".image"]
if not any([f"{role}{postfix}" in roles for postfix in postfixes]):
raise Exception(f"Role `{role}` not found")
@@ -91,6 +91,7 @@ def load_role_config(role):
'role_complete': parse_role_section(roles.get(f"{role}.complete", {})),
'role_edit': parse_role_section(roles.get(f"{role}.edit", {})),
'role_chat': parse_role_section(roles.get(f"{role}.chat", {})),
+ 'role_image': parse_role_section(roles.get(f"{role}.image", {})),
}
def parse_role_names(prompt):
@@ -147,7 +148,7 @@ def make_ai_context(params):
user_prompt, role_config = parse_prompt_and_role_config(user_instruction, command_type)
final_config = merge_deep([config_default, config_extension, role_config])
- selection_boundary = final_config['options']['selection_boundary']
+ selection_boundary = final_config['options'].get('selection_boundary', '')
config_prompt = final_config.get('prompt', '')
prompt = make_prompt(config_prompt, user_prompt, user_selection, selection_boundary)
diff --git a/py/image.py b/py/image.py
new file mode 100644
index 0000000..4af7b6b
--- /dev/null
+++ b/py/image.py
@@ -0,0 +1,50 @@
+import vim
+import datetime
+import os
+
+image_py_imported = True
+
+def make_openai_image_options(options):
+ return {
+ 'model': options['model'],
+ 'quality': 'standard',
+ 'size': '1024x1024',
+ 'style': 'vivid',
+ 'response_format': 'b64_json',
+ }
+
+def make_image_path(ui):
+ download_dir = ui.get('download_dir', vim.eval('getcwd()'))
+ timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%dT%H%M%SZ")
+ filename = f'vim_ai_{timestamp}.png'
+ return os.path.join(download_dir, filename)
+
+def run_ai_image(context):
+ prompt = context['prompt']
+ config = context['config']
+ config_options = config['options']
+ ui = config['ui']
+
+ try:
+ if prompt:
+ print('Generating...')
+ openai_options = make_openai_image_options(config_options)
+ http_options = make_http_options(config_options)
+ request = { 'prompt': prompt, **openai_options }
+
+ print_debug("[image] text:\n" + prompt)
+ print_debug("[image] request: {}", request)
+ url = config_options['endpoint_url']
+
+ response, *_ = openai_request(url, request, http_options)
+ print_debug("[image] response: {}", { 'images_count': len(response['data']) })
+
+ path = make_image_path(ui)
+ b64_data = response['data'][0]['b64_json']
+ save_b64_to_file(path, b64_data)
+
+ clear_echo_message()
+ print(f"Image: {path}")
+ except BaseException as error:
+ handle_completion_error(error)
+ print_debug("[image] error: {}", traceback.format_exc())
diff --git a/py/roles.py b/py/roles.py
index bb5356e..7f038b1 100644
--- a/py/roles.py
+++ b/py/roles.py
@@ -13,8 +13,13 @@ def load_ai_role_names(command_type):
role_names = set()
for name in roles.sections():
parts = name.split('.')
- if len(parts) == 1 or parts[-1] == command_type:
- role_names.add(parts[0])
+ if command_type == 'image':
+ # special case - image type have to be explicitely defined
+ if len(parts) > 1 and parts[-1] == command_type:
+ role_names.add(parts[0])
+ else:
+ if len(parts) == 1 or parts[-1] == command_type:
+ role_names.add(parts[0])
role_names = [name for name in role_names if name != DEFAULT_ROLE_NAME]
diff --git a/py/utils.py b/py/utils.py
index d118326..8cd204f 100644
--- a/py/utils.py
+++ b/py/utils.py
@@ -250,7 +250,7 @@ def openai_request(url, data, options):
)
with urllib.request.urlopen(req, timeout=request_timeout) as response:
- if not data['stream']:
+ if not data.get('stream', 0):
yield json.loads(response.read().decode())
return
for line_bytes in response:
@@ -354,3 +354,8 @@ def read_role_files():
roles = configparser.ConfigParser()
roles.read([default_roles_config_path, roles_config_path])
return roles
+
+def save_b64_to_file(path, b64_data):
+ f = open(path, "wb")
+ f.write(base64.b64decode(b64_data))
+ f.close()