Skip to content

Commit

Permalink
feature: added image creation with Dall-E
Browse files Browse the repository at this point in the history
  • Loading branch information
DevEmperor committed Feb 13, 2024
1 parent 276695a commit dca6af5
Show file tree
Hide file tree
Showing 30 changed files with 1,451 additions and 48 deletions.
5 changes: 4 additions & 1 deletion app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@ dependencies {
implementation 'androidx.core:core-splashscreen:1.0.1'
implementation 'androidx.wear.watchface:watchface-complications-data-source-ktx:1.1.1'

implementation 'com.theokanning.openai-gpt3-java:service:0.16.1'
implementation 'com.theokanning.openai-gpt3-java:service:0.18.2'
implementation 'com.squareup.retrofit2:retrofit:2.9.0'
implementation 'com.squareup.retrofit2:converter-jackson:2.9.0'
implementation 'com.fasterxml.jackson.core:jackson-databind:2.15.2'
implementation 'commons-validator:commons-validator:1.7'
implementation 'com.jsibbold:zoomage:1.3.1'
implementation 'com.github.kenglxn.QRGen:android:3.0.1'
implementation 'com.squareup.picasso:picasso:2.8'

implementation 'io.noties.markwon:core:4.6.2'
implementation 'io.noties.markwon:ext-strikethrough:4.6.2'
Expand Down
14 changes: 12 additions & 2 deletions app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
<meta-data
android:name="android.support.wearable.complications.SUPPORTED_TYPES"
android:value="SMALL_IMAGE,MONOCHROMATIC_IMAGE,LONG_TEXT" />

<meta-data
android:name="android.support.wearable.complications.UPDATE_PERIOD_SECONDS"
android:value="0" />
Expand All @@ -50,6 +49,18 @@
<activity
android:name=".activities.EditChatActivity"
android:exported="false" />
<activity
android:name=".activities.ImageActivity"
android:exported="false" />
<activity
android:name=".activities.CreateImageActivity"
android:exported="false" />
<activity
android:name=".activities.OpenImageActivity"
android:exported="false" />
<activity
android:name=".activities.QRCodeActivity"
android:exported="false" />
<activity
android:name=".activities.UsageActivity"
android:exported="false" />
Expand All @@ -71,7 +82,6 @@
<activity
android:name=".activities.MainActivity"
android:exported="true"
android:launchMode="singleInstance"
android:theme="@style/Theme.Starting">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ private void query(String query) throws JSONException, IOException {
Usage usage = result.getUsage();
ChatItem assistantItem = new ChatItem(answer, usage.getTotalTokens());

usageDatabaseHelper.edit(finalModel, usage.getTotalTokens(), Util.calcCost(finalModel, usage.getPromptTokens(), usage.getCompletionTokens()));
usageDatabaseHelper.edit(finalModel, usage.getTotalTokens(), Util.calcCostChat(finalModel, usage.getPromptTokens(), usage.getCompletionTokens()));

if (Thread.interrupted()) {
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
package net.devemperor.wristassist.activities;

import static com.theokanning.openai.service.OpenAiService.defaultClient;
import static com.theokanning.openai.service.OpenAiService.defaultObjectMapper;

import android.content.Intent;
import android.content.SharedPreferences;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.os.VibrationEffect;
import android.os.Vibrator;
import android.view.View;
import android.widget.ImageButton;
import android.widget.ProgressBar;
import android.widget.ScrollView;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;
import androidx.constraintlayout.widget.ConstraintLayout;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.firebase.crashlytics.FirebaseCrashlytics;
import com.jsibbold.zoomage.ZoomageView;
import com.theokanning.openai.client.OpenAiApi;
import com.theokanning.openai.image.CreateImageRequest;
import com.theokanning.openai.image.Image;
import com.theokanning.openai.image.ImageResult;
import com.theokanning.openai.service.OpenAiService;

import net.devemperor.wristassist.R;
import net.devemperor.wristassist.database.ImageModel;
import net.devemperor.wristassist.database.ImagesDatabaseHelper;
import net.devemperor.wristassist.database.UsageDatabaseHelper;
import net.devemperor.wristassist.util.Util;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
import java.util.Objects;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

public class CreateImageActivity extends AppCompatActivity {

SharedPreferences sp;
UsageDatabaseHelper usageDatabaseHelper;
ImagesDatabaseHelper imagesDatabaseHelper;
OpenAiService service;
Vibrator vibrator;

ScrollView createImageSv;
ProgressBar imagePb;
TextView errorTv;
ImageButton retryBtn;
ZoomageView imageView;
ImageButton shareBtn;
TextView expiresInTv;
ConstraintLayout saveDiscardBtns;

String prompt;
String model;
String quality;
String style;
String size;
ImageResult imageResult;
Image image;
Bitmap bitmap;
ExecutorService thread;
Timer timer = new Timer();

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_create_image);

sp = getSharedPreferences("net.devemperor.wristassist", MODE_PRIVATE);
imagesDatabaseHelper = new ImagesDatabaseHelper(this);
usageDatabaseHelper = new UsageDatabaseHelper(this);

String apiKey = sp.getString("net.devemperor.wristassist.api_key", "noApiKey");
String apiHost = sp.getString("net.devemperor.wristassist.custom_server_host", "https://api.openai.com/");
ObjectMapper mapper = defaultObjectMapper(); // replaces all control chars (#10 @ GH)
OkHttpClient client = defaultClient(apiKey.replaceAll("[^ -~]", ""), Duration.ofSeconds(120)).newBuilder().build();
Retrofit retrofit = new Retrofit.Builder()
.baseUrl(apiHost)
.client(client)
.addConverterFactory(JacksonConverterFactory.create(mapper))
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
.build();
OpenAiApi api = retrofit.create(OpenAiApi.class);

service = new OpenAiService(api);
vibrator = (Vibrator) getSystemService(VIBRATOR_SERVICE);

createImageSv = findViewById(R.id.create_image_sv);
imagePb = findViewById(R.id.image_pb);
errorTv = findViewById(R.id.error_image_tv);
retryBtn = findViewById(R.id.retry_image_btn);
imageView = findViewById(R.id.create_image_iv);
shareBtn = findViewById(R.id.share_image_btn);
expiresInTv = findViewById(R.id.expires_image_tv);
saveDiscardBtns = findViewById(R.id.save_discard_image_btns);

prompt = getIntent().getStringExtra("net.devemperor.wristassist.prompt");
model = sp.getBoolean("net.devemperor.wristassist.image_model", false) ? "dall-e-3" : "dall-e-2";
quality = sp.getBoolean("net.devemperor.wristassist.image_quality", false) ? "hd" : "standard";
style = sp.getBoolean("net.devemperor.wristassist.image_style", false) ? "natural" : "vivid";
size = sp.getBoolean("net.devemperor.wristassist.image_model", false) ? "1024x1024" : sp.getString("net.devemperor.wristassist.image_size", "1024x1024");

createAndDownloadImage();
createImageSv.requestFocus();
}

@Override
protected void onDestroy() {
super.onDestroy();
timer.cancel();
if (thread != null) {
thread.shutdownNow();
}
}

private void createAndDownloadImage() {
imagePb.setVisibility(View.VISIBLE);
errorTv.setVisibility(View.GONE);
retryBtn.setVisibility(View.GONE);

thread = Executors.newSingleThreadExecutor();
thread.execute(() -> {
try {
CreateImageRequest cir = CreateImageRequest.builder()
.responseFormat("url")
.n(1)
.prompt(prompt)
.model(model)
.quality(quality)
.size(size)
.style(style)
.build();
imageResult = service.createImage(cir);
image = imageResult.getData().get(0);

timer.scheduleAtFixedRate(new TimerTask() {
@Override
public void run() {
long minutes = (imageResult.getCreated()*1000 + 60*60*1000 - System.currentTimeMillis()) / 60 / 1000;
runOnUiThread(() -> {
if (minutes <= 0) {
expiresInTv.setVisibility(View.GONE);
shareBtn.setVisibility(View.GONE);
timer.cancel();
} else {
expiresInTv.setText(getString(R.string.wristassist_image_expires_in, minutes));
}
});
}
}, 0, 60*1000);

usageDatabaseHelper.edit(model, 1, Util.calcCostImage(model, quality, size));

OkHttpClient downloadClient = new OkHttpClient();
Request request = new Request.Builder().url(image.getUrl()).build();

Response response = downloadClient.newCall(request).execute();
if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response);
}

assert response.body() != null;
InputStream inputStream = response.body().byteStream();
bitmap = BitmapFactory.decodeStream(inputStream);
if (bitmap == null) {
throw new IOException("Bitmap is null");
} else {
runOnUiThread(() -> {
if (sp.getBoolean("net.devemperor.wristassist.vibrate", true)) {
vibrator.vibrate(VibrationEffect.createOneShot(300, VibrationEffect.DEFAULT_AMPLITUDE));
}

imageView.setImageBitmap(bitmap);
imagePb.setVisibility(View.GONE);
imageView.setVisibility(View.VISIBLE);
shareBtn.setVisibility(View.VISIBLE);
expiresInTv.setVisibility(View.VISIBLE);
saveDiscardBtns.setVisibility(View.VISIBLE);
});
}
} catch (RuntimeException | IOException e) {
FirebaseCrashlytics fc = FirebaseCrashlytics.getInstance();
fc.setCustomKey("settings", sp.getAll().toString());
fc.setUserId(sp.getString("net.devemperor.wristassist.userid", "null"));
fc.recordException(e);
fc.sendUnsentReports();

e.printStackTrace();
runOnUiThread(() -> {
imagePb.setVisibility(View.GONE);
errorTv.setVisibility(View.VISIBLE);
retryBtn.setVisibility(View.VISIBLE);
timer.cancel();

if (sp.getBoolean("net.devemperor.wristassist.vibrate", true)) {
vibrator.vibrate(VibrationEffect.createWaveform(new long[]{50, 50, 50, 50, 50}, new int[]{-1, 0, -1, 0, -1}, -1));
}

if (Objects.requireNonNull(e.getMessage()).contains("SocketTimeoutException")) {
errorTv.setText(R.string.wristassist_timeout);
} else if (e.getMessage().contains("API key")) {
errorTv.setText(getString(R.string.wristassist_invalid_api_key_message));
} else if (e.getMessage().contains("rejected")) {
errorTv.setText(R.string.wristassist_image_request_rejected);
} else if (e.getMessage().contains("quota") || e.getMessage().contains("limit")) {
errorTv.setText(R.string.wristassist_quota_exceeded);
} else if (e.getMessage().contains("does not exist")) {
errorTv.setText(R.string.wristassist_no_access);
} else {
errorTv.setText(R.string.wristassist_no_internet);
}
});
}
});
}

public void retry(View view) {
createAndDownloadImage();
}

public void shareImage(View view) {
Intent intent = new Intent(this, QRCodeActivity.class);
intent.putExtra("net.devemperor.wristassist.image_url", image.getUrl());
startActivity(intent);
}

public void saveImage(View view) {
ImageModel imageModel;
if (model.equals("dall-e-3")) {
imageModel = new ImageModel(-1, prompt, image.getRevisedPrompt(), model, quality, size, style, imageResult.getCreated() * 1000, image.getUrl());
} else {
imageModel = new ImageModel(-1, prompt, null, model, null, size, null, imageResult.getCreated() * 1000, image.getUrl());
}
int id = imagesDatabaseHelper.add(imageModel);

try {
FileOutputStream out = openFileOutput("image_" + id + ".png", MODE_PRIVATE);
bitmap.compress(Bitmap.CompressFormat.PNG, 90, out);
out.flush();
out.close();
} catch (IOException e) {
e.printStackTrace();
}
timer.cancel();

Intent data = new Intent();
data.putExtra("net.devemperor.wristassist.imageId", id);
setResult(RESULT_OK, data);
finish();
}

public void discardImage(View view) {
timer.cancel();
finish();
}
}
Loading

0 comments on commit dca6af5

Please sign in to comment.