Skip to content

Commit

Permalink
feature: usage activity
Browse files Browse the repository at this point in the history
  • Loading branch information
DevEmperor committed Jan 30, 2024
1 parent 0f0668f commit 4cd892d
Show file tree
Hide file tree
Showing 17 changed files with 382 additions and 51 deletions.
4 changes: 4 additions & 0 deletions app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
<activity
android:name=".activities.EditChatActivity"
android:exported="false" />
<activity
android:name=".activities.UsageActivity"
android:exported="false" />
<activity
android:name=".activities.SettingsActivity"
android:exported="false" />
Expand All @@ -68,6 +71,7 @@
<activity
android:name=".activities.MainActivity"
android:exported="true"
android:launchMode="singleInstance"
android:theme="@style/Theme.Starting">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package net.devemperor.wristassist.activities;

import android.app.Activity;
import android.content.SharedPreferences;
import android.os.Bundle;
import android.widget.ImageView;
import android.widget.TextView;
Expand All @@ -11,15 +10,9 @@
import net.devemperor.wristassist.R;
import net.devemperor.wristassist.util.Util;

import java.text.DecimalFormat;


public class AboutActivity extends Activity {

DecimalFormat df = new DecimalFormat("#.#");
TextView totalCost;
SharedPreferences sp;

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
Expand All @@ -29,26 +22,11 @@ protected void onCreate(Bundle savedInstanceState) {
aboutText.setText(getString(R.string.wristassist_about, BuildConfig.VERSION_NAME));
aboutText.setTextSize(16 * Util.getFontMultiplier(this));

totalCost = findViewById(R.id.total_cost_tv);
totalCost.setTextSize(16 * Util.getFontMultiplier(this));
sp = getSharedPreferences("net.devemperor.wristassist", MODE_PRIVATE);
refreshTotalCostTv();

totalCost.setOnLongClickListener(v -> {
sp.edit().putLong("net.devemperor.wristassist.total_tokens", 0).apply();
Toast.makeText(v.getContext(), R.string.wristassist_reset_cost, Toast.LENGTH_SHORT).show();
refreshTotalCostTv();
return true;
});

ImageView icon = findViewById(R.id.icon);
icon.setOnLongClickListener(v -> {
Toast.makeText(v.getContext(), sp.getString("net.devemperor.wristassist.userid", "null"), Toast.LENGTH_LONG).show();
Toast.makeText(v.getContext(), getSharedPreferences("net.devemperor.wristassist", MODE_PRIVATE)
.getString("net.devemperor.wristassist.userid", "null"), Toast.LENGTH_LONG).show();
return true;
});
}

private void refreshTotalCostTv() {
totalCost.setText(getString(R.string.wristassist_total_cost, df.format(sp.getLong("net.devemperor.wristassist.total_tokens", 0) / 1000.0)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.firebase.crashlytics.FirebaseCrashlytics;
import com.theokanning.openai.Usage;
import com.theokanning.openai.client.OpenAiApi;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
Expand All @@ -32,7 +33,9 @@
import net.devemperor.wristassist.adapters.ChatAdapter;
import net.devemperor.wristassist.database.ChatHistoryDatabaseHelper;
import net.devemperor.wristassist.database.ChatHistoryModel;
import net.devemperor.wristassist.database.UsageDatabaseHelper;
import net.devemperor.wristassist.items.ChatItem;
import net.devemperor.wristassist.util.Util;

import org.json.JSONArray;
import org.json.JSONException;
Expand Down Expand Up @@ -68,6 +71,7 @@ public class ChatActivity extends Activity {
Vibrator vibrator;

ChatHistoryDatabaseHelper chatHistoryDatabaseHelper;
UsageDatabaseHelper usageDatabaseHelper;
SharedPreferences sp;

boolean firstAnswerComplete = false;
Expand Down Expand Up @@ -96,6 +100,7 @@ protected void onCreate(Bundle savedInstanceState) {
vibrator = (Vibrator) getSystemService(VIBRATOR_SERVICE);

chatHistoryDatabaseHelper = new ChatHistoryDatabaseHelper(this);
usageDatabaseHelper = new UsageDatabaseHelper(this);
sp = getSharedPreferences("net.devemperor.wristassist", MODE_PRIVATE);

String apiKey = sp.getString("net.devemperor.wristassist.api_key", "noApiKey");
Expand Down Expand Up @@ -263,13 +268,16 @@ private void query(String query) throws JSONException, IOException {
.build();

thread = Executors.newSingleThreadExecutor();
String finalModel = model;
thread.execute(() -> {
try {
ChatCompletionResult result = service.createChatCompletion(ccr);
ChatMessage answer = result.getChoices().get(0).getMessage();
long cost = result.getUsage().getTotalTokens();
ChatItem assistantItem = new ChatItem(answer, cost);
sp.edit().putLong("net.devemperor.wristassist.total_tokens", sp.getLong("net.devemperor.wristassist.total_tokens", 0) + cost).apply();
Usage usage = result.getUsage();
ChatItem assistantItem = new ChatItem(answer, usage.getTotalTokens());

usageDatabaseHelper.edit(finalModel, usage.getTotalTokens(), Util.calcCost(finalModel, usage.getPromptTokens(), usage.getCompletionTokens()));

if (Thread.interrupted()) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ protected void onCreate(Bundle savedInstanceState) {
ArrayList<MainItem> menuItems = new ArrayList<>();
menuItems.add(new MainItem(R.drawable.twotone_add_24, getString(R.string.wristassist_menu_new_chat)));
menuItems.add(new MainItem(R.drawable.twotone_chat_24, getString(R.string.wristassist_menu_saved_chats)));
menuItems.add(new MainItem(R.drawable.twotone_insert_chart_outlined_24, getString(R.string.wristassist_menu_usage)));
menuItems.add(new MainItem(R.drawable.twotone_settings_24, getString(R.string.wristassist_menu_settings)));
menuItems.add(new MainItem(R.drawable.twotone_info_24, getString(R.string.wristassist_menu_about)));

Expand All @@ -73,9 +74,12 @@ protected void onCreate(Bundle savedInstanceState) {
intent = new Intent(this, SavedChatsActivity.class);
startActivity(intent);
} else if (menuPosition == 2) {
intent = new Intent(this, SettingsActivity.class);
intent = new Intent(this, UsageActivity.class);
startActivity(intent);
} else if (menuPosition == 3) {
intent = new Intent(this, SettingsActivity.class);
startActivity(intent);
} else if (menuPosition == 4) {
intent = new Intent(this, AboutActivity.class);
startActivity(intent);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package net.devemperor.wristassist.activities;

import android.os.Bundle;
import android.view.LayoutInflater;
import android.view.View;
import android.widget.Button;
import android.widget.ListView;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import net.devemperor.wristassist.R;
import net.devemperor.wristassist.adapters.UsageAdapter;
import net.devemperor.wristassist.database.UsageDatabaseHelper;

import java.util.Locale;

public class UsageActivity extends AppCompatActivity {

ListView usageLv;
Button resetUsageBtn;
TextView totalCostTv;

UsageDatabaseHelper usageDatabaseHelper;
UsageAdapter usageAdapter;

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_usage);

usageDatabaseHelper = new UsageDatabaseHelper(this);
usageAdapter = new UsageAdapter(this, usageDatabaseHelper.getAll());
usageLv = findViewById(R.id.usage_lv);
usageLv.setAdapter(usageAdapter);

View footerView = LayoutInflater.from(this).inflate(R.layout.layout_usage_footer, usageLv, false);
resetUsageBtn = footerView.findViewById(R.id.usage_reset_btn);

totalCostTv = footerView.findViewById(R.id.usage_total_cost_tv);
totalCostTv.setText(getString(R.string.wristassist_total_cost,
String.format(Locale.getDefault(), "%,.2f", usageDatabaseHelper.getTotalCost())));

usageLv.addFooterView(footerView);

usageLv.requestFocus();

if (usageAdapter.getCount() == 0) {
noUsage();
}
}

public void resetUsage(View view) {
usageDatabaseHelper.reset();
usageAdapter.clear();
usageAdapter.addAll(usageDatabaseHelper.getAll());
usageAdapter.notifyDataSetChanged();

noUsage();
}

private void noUsage() {
totalCostTv.setText(getString(R.string.wristassist_no_usage_yet));
resetUsageBtn.setEnabled(false);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package net.devemperor.wristassist.adapters;

import android.content.Context;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.ArrayAdapter;
import android.widget.TextView;

import androidx.annotation.NonNull;

import net.devemperor.wristassist.R;
import net.devemperor.wristassist.database.UsageModel;
import net.devemperor.wristassist.util.Util;

import java.util.List;
import java.util.Locale;


public class UsageAdapter extends ArrayAdapter<UsageModel> {
final Context context;
final List<UsageModel> objects;


public UsageAdapter(@NonNull Context context, @NonNull List<UsageModel> objects) {
super(context, -1, objects);
this.context = context;
this.objects = objects;
}

@NonNull
@Override
public View getView (int position, View convertView, @NonNull ViewGroup parent) {
View listItem = LayoutInflater.from(context).inflate(R.layout.item_usage, parent, false);

UsageModel dataProvider = objects.get(position);

TextView modelNameTv = listItem.findViewById(R.id.usage_model_tv);
modelNameTv.setText(Util.translateModelNames(dataProvider.getModelName()));
modelNameTv.setTextSize(18 * Util.getFontMultiplier(context));

TextView tokensTv = listItem.findViewById(R.id.usage_tokens_tv);
tokensTv.setText(context.getString(R.string.wristassist_token_usage,
String.format(Locale.getDefault(), "%,d", dataProvider.getTokens())));
tokensTv.setTextSize(16 * Util.getFontMultiplier(context));

TextView costTv = listItem.findViewById(R.id.usage_cost_tv);
costTv.setText(context.getString(R.string.wristassist_estimated_cost,
String.format(Locale.getDefault(), "%,.2f", dataProvider.getCost())));
costTv.setTextSize(16 * Util.getFontMultiplier(context));

return listItem;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package net.devemperor.wristassist.database;

import android.content.ContentValues;
import android.content.Context;
import android.database.Cursor;
import android.database.sqlite.SQLiteDatabase;
import android.database.sqlite.SQLiteOpenHelper;

import androidx.annotation.Nullable;

import java.util.ArrayList;
import java.util.List;

public class UsageDatabaseHelper extends SQLiteOpenHelper {

Context context;

public UsageDatabaseHelper(@Nullable Context context) {
super(context, "usage.db", null, 1);
this.context = context;
}

@Override
public void onCreate(SQLiteDatabase db) {
db.execSQL("CREATE TABLE USAGE (MODEL_NAME TEXT PRIMARY KEY, TOKENS LONG, COST DOUBLE)");
}

@Override
public void onUpgrade(SQLiteDatabase db, int oldVersion, int newVersion) { }

public void edit(String model, long tokensToAdd, double costToAdd) {
SQLiteDatabase db = this.getWritableDatabase();
Cursor cursor = db.rawQuery("SELECT * FROM USAGE WHERE MODEL_NAME='" + model + "'", null);

boolean entryExists = cursor.moveToFirst();
cursor.close();

if (!entryExists) {
ContentValues cv = new ContentValues();
cv.put("MODEL_NAME", model);
cv.put("TOKENS", tokensToAdd);
cv.put("COST", costToAdd);
db.insert("USAGE", null, cv);
} else {
cursor = db.rawQuery("SELECT * FROM USAGE WHERE MODEL_NAME='" + model + "'", null);
if (cursor.moveToFirst()) {
long lastTokens = cursor.getLong(1);
double lastCost = cursor.getDouble(2);
ContentValues cv = new ContentValues();
cv.put("TOKENS", lastTokens + tokensToAdd);
cv.put("COST", lastCost + costToAdd);
db.update("USAGE", cv, "MODEL_NAME='" + model + "'", null);
}
cursor.close();
}

db.close();
}

public void reset() {
SQLiteDatabase db = this.getWritableDatabase();
db.execSQL("DELETE FROM USAGE");
db.close();
}

public List<UsageModel> getAll() {
SQLiteDatabase db = this.getWritableDatabase();
Cursor cursor = db.rawQuery("SELECT * FROM USAGE", null);

List<UsageModel> models = new ArrayList<>();
if (cursor.moveToFirst()) {
do {
models.add(new UsageModel(cursor.getString(0), cursor.getLong(1), cursor.getDouble(2)));
} while (cursor.moveToNext());
}
cursor.close();
db.close();
return models;
}

public double getTotalCost() {
SQLiteDatabase db = this.getWritableDatabase();
Cursor cursor = db.rawQuery("SELECT SUM(COST) FROM USAGE", null);

double totalCost = 0;
if (cursor.moveToFirst()) {
totalCost = cursor.getDouble(0);
}
cursor.close();
db.close();
return totalCost;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package net.devemperor.wristassist.database;

public class UsageModel {
private final String modelName;
private final long tokens;
private final double cost;

public UsageModel(String modelName, long tokens, double cost) {
this.modelName = modelName;
this.tokens = tokens;
this.cost = cost;
}

public String getModelName() {
return modelName;
}

public long getTokens() {
return tokens;
}

public double getCost() {
return cost;
}
}
Loading

0 comments on commit 4cd892d

Please sign in to comment.