From fbc2bfb445c71985e9fc399d3fac2def2fc6854e Mon Sep 17 00:00:00 2001
From: Martin Bielik
Date: Sun, 22 Dec 2024 14:55:15 +0100
Subject: added image generation

---
 py/context.py |  5 +++--
 py/image.py   | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++
 py/roles.py   |  9 +++++++--
 py/utils.py   |  7 ++++++-
 4 files changed, 66 insertions(+), 5 deletions(-)
 create mode 100644 py/image.py

(limited to 'py')

Review notes (not part of the commit; text in this area is ignored when
the patch is applied):

  * Layout: the line breaks and Python indentation of this patch were
    reconstructed from a whitespace-collapsed copy. Blank context lines
    were placed so that every hunk matches its "@@ -a,b +c,d @@" counts;
    verify against the repository before applying.
  * py/image.py, make_image_path(): `datetime.UTC` only exists on
    Python 3.11 and newer; `datetime.timezone.utc` is the same object
    and is available on older interpreters.
  * py/utils.py, save_b64_to_file(): the file is opened without a
    `with` block, so it is not explicitly closed if decoding or writing
    raises. `base64` is not imported anywhere in this patch; presumably
    utils.py already imports it - confirm.
  * py/image.py calls make_http_options, print_debug, openai_request,
    save_b64_to_file, clear_echo_message, handle_completion_error and
    traceback without importing them; presumably they are supplied by
    the namespace the plugin loads its scripts into - confirm.
  * py/image.py, make_openai_image_options(): only `model` is read from
    the options; quality, size, style and response_format are fixed.

diff --git a/py/context.py b/py/context.py
index 581f8ad..343a050 100644
--- a/py/context.py
+++ b/py/context.py
@@ -79,7 +79,7 @@ def load_role_config(role):
 
     enhance_roles_with_custom_function(roles)
 
-    postfixes = ["", ".complete", ".edit", ".chat"]
+    postfixes = ["", ".complete", ".edit", ".chat", ".image"]
     if not any([f"{role}{postfix}" in roles for postfix in postfixes]):
         raise Exception(f"Role `{role}` not found")
 
@@ -91,6 +91,7 @@ def load_role_config(role):
         'role_complete': parse_role_section(roles.get(f"{role}.complete", {})),
         'role_edit': parse_role_section(roles.get(f"{role}.edit", {})),
         'role_chat': parse_role_section(roles.get(f"{role}.chat", {})),
+        'role_image': parse_role_section(roles.get(f"{role}.image", {})),
     }
 
 def parse_role_names(prompt):
@@ -147,7 +148,7 @@ def make_ai_context(params):
 
     user_prompt, role_config = parse_prompt_and_role_config(user_instruction, command_type)
     final_config = merge_deep([config_default, config_extension, role_config])
-    selection_boundary = final_config['options']['selection_boundary']
+    selection_boundary = final_config['options'].get('selection_boundary', '')
     config_prompt = final_config.get('prompt', '')
     prompt = make_prompt(config_prompt, user_prompt, user_selection, selection_boundary)
 
diff --git a/py/image.py b/py/image.py
new file mode 100644
index 0000000..4af7b6b
--- /dev/null
+++ b/py/image.py
@@ -0,0 +1,50 @@
+import vim
+import datetime
+import os
+
+image_py_imported = True
+
+def make_openai_image_options(options):
+    return {
+        'model': options['model'],
+        'quality': 'standard',
+        'size': '1024x1024',
+        'style': 'vivid',
+        'response_format': 'b64_json',
+    }
+
+def make_image_path(ui):
+    download_dir = ui.get('download_dir', vim.eval('getcwd()'))
+    timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%dT%H%M%SZ")
+    filename = f'vim_ai_{timestamp}.png'
+    return os.path.join(download_dir, filename)
+
+def run_ai_image(context):
+    prompt = context['prompt']
+    config = context['config']
+    config_options = config['options']
+    ui = config['ui']
+
+    try:
+        if prompt:
+            print('Generating...')
+            openai_options = make_openai_image_options(config_options)
+            http_options = make_http_options(config_options)
+            request = { 'prompt': prompt, **openai_options }
+
+            print_debug("[image] text:\n" + prompt)
+            print_debug("[image] request: {}", request)
+            url = config_options['endpoint_url']
+
+            response, *_ = openai_request(url, request, http_options)
+            print_debug("[image] response: {}", { 'images_count': len(response['data']) })
+
+            path = make_image_path(ui)
+            b64_data = response['data'][0]['b64_json']
+            save_b64_to_file(path, b64_data)
+
+            clear_echo_message()
+            print(f"Image: {path}")
+    except BaseException as error:
+        handle_completion_error(error)
+        print_debug("[image] error: {}", traceback.format_exc())
diff --git a/py/roles.py b/py/roles.py
index bb5356e..7f038b1 100644
--- a/py/roles.py
+++ b/py/roles.py
@@ -13,8 +13,13 @@ def load_ai_role_names(command_type):
     role_names = set()
     for name in roles.sections():
         parts = name.split('.')
-        if len(parts) == 1 or parts[-1] == command_type:
-            role_names.add(parts[0])
+        if command_type == 'image':
+            # special case - image type have to be explicitely defined
+            if len(parts) > 1 and parts[-1] == command_type:
+                role_names.add(parts[0])
+        else:
+            if len(parts) == 1 or parts[-1] == command_type:
+                role_names.add(parts[0])
 
     role_names = [name for name in role_names if name != DEFAULT_ROLE_NAME]
 
diff --git a/py/utils.py b/py/utils.py
index d118326..8cd204f 100644
--- a/py/utils.py
+++ b/py/utils.py
@@ -250,7 +250,7 @@ def openai_request(url, data, options):
     )
 
     with urllib.request.urlopen(req, timeout=request_timeout) as response:
-        if not data['stream']:
+        if not data.get('stream', 0):
             yield json.loads(response.read().decode())
             return
         for line_bytes in response:
@@ -354,3 +354,8 @@ def read_role_files():
     roles = configparser.ConfigParser()
     roles.read([default_roles_config_path, roles_config_path])
     return roles
+
+def save_b64_to_file(path, b64_data):
+    f = open(path, "wb")
+    f.write(base64.b64decode(b64_data))
+    f.close()
-- 
cgit v1.2.3