OpenAI SDK Development (1)
This installment sets up the basic framework. The project is organized into the packages walked through below.
common
Constants
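The common package defines a Constants class. For now it holds a single enum, Role, which is used as a field of Message. Because Message appears in both Request and Response, the enum lives in the common package; it is used later on.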
public class Constants {

    public enum Role {
        SYSTEM("system"),
        USER("user"),
        ASSISTANT("assistant"),
        ;

        // Value sent on the wire in the "role" field
        private final String code;

        Role(String code) {
            this.code = code;
        }

        public String getCode() {
            return code;
        }
    }
}
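Since Role stores the wire value in code, mapping a string from a response back to a constant is a natural companion. The sketch below is a stand-alone copy of the enum with a hypothetical fromCode helper; neither RoleSketch nor fromCode exists in the project.

```java
import java.util.Arrays;
import java.util.Optional;

// Stand-alone copy of the Role enum; RoleSketch and fromCode are hypothetical names.
enum RoleSketch {
    SYSTEM("system"), USER("user"), ASSISTANT("assistant");

    private final String code;

    RoleSketch(String code) {
        this.code = code;
    }

    String getCode() {
        return code;
    }

    // Reverse lookup: wire value -> constant, empty when the value is unknown.
    static Optional<RoleSketch> fromCode(String code) {
        return Arrays.stream(values())
                .filter(r -> r.code.equals(code))
                .findFirst();
    }
}
```

Returning an Optional rather than throwing leaves it to the caller to decide how to treat a role the enum does not know about.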
domain
chat (chat model)
ChatChoice
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;

import java.io.Serializable;

@Data
public class ChatChoice implements Serializable {

    private long index;

    @JsonProperty("message")
    private Message message;

    @JsonProperty("finish_reason")
    private String finishReason;
}
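ChatChoice holds the fields of each entry in choices; choices itself is a field of the Response.
ChatCompletionRequest (chat completion request)
The model is factored out into its own enum, and the class declares the parameters the request needs.
Only model and messages are required; everything else is optional, so the object is constructed with the Builder pattern.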
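import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

@Data
@Builder
@Slf4j
@JsonInclude(JsonInclude.Include.NON_NULL)
@NoArgsConstructor
@AllArgsConstructor
public class ChatCompletionRequest implements Serializable {

    // @Builder.Default keeps each initializer when the object is created through builder();
    // without it, Lombok's builder ignores the initializer and leaves null / 0 / false.
    @Builder.Default
    private String model = Model.GPT_3_5_TURBO.getCode();

    private List<Message> messages;

    @Builder.Default
    private double temperature = 0.2;

    @Builder.Default
    @JsonProperty("top_p")
    private Double topP = 1d;

    @Builder.Default
    private Integer n = 1;

    @Builder.Default
    private boolean stream = false;

    private List<String> stop;

    @Builder.Default
    @JsonProperty("max_tokens")
    private Integer maxTokens = 2048;

    @Builder.Default
    @JsonProperty("frequency_penalty")
    private double frequencyPenalty = 0;

    @Builder.Default
    @JsonProperty("presence_penalty")
    private double presencePenalty = 0;

    @JsonProperty("logit_bias")
    private Map logitBias;

    private String user;

    @Getter
    @AllArgsConstructor
    public enum Model {
        GPT_3_5_TURBO("gpt-3.5-turbo"),
        GPT_4("gpt-4"),
        GPT_4_32K("gpt-4-32k"),
        ;

        private String code;
    }
}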
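To make the required-versus-optional split concrete, here is a hand-written builder in plain Java. It is only a sketch of the idea, not the code Lombok generates: RequestSketch and its fields are hypothetical, messages are simplified to strings, and the defaults mirror the values used in ChatCompletionRequest.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hand-written illustration of a builder with defaults; RequestSketch is a hypothetical name.
final class RequestSketch {
    final String model;
    final List<String> messages;
    final double temperature;
    final Integer maxTokens;

    private RequestSketch(Builder b) {
        this.model = b.model;
        this.messages = Collections.unmodifiableList(new ArrayList<>(b.messages));
        this.temperature = b.temperature;
        this.maxTokens = b.maxTokens;
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {
        // Optional parameters start from defaults held by the builder itself.
        private String model = "gpt-3.5-turbo";
        private final List<String> messages = new ArrayList<>();
        private double temperature = 0.2;
        private Integer maxTokens = 2048;

        Builder model(String model) { this.model = model; return this; }
        Builder message(String content) { this.messages.add(content); return this; }
        Builder temperature(double temperature) { this.temperature = temperature; return this; }
        Builder maxTokens(Integer maxTokens) { this.maxTokens = maxTokens; return this; }

        RequestSketch build() {
            // model and messages are the two required parameters.
            if (model == null || model.isEmpty()) {
                throw new IllegalStateException("model is required");
            }
            if (messages.isEmpty()) {
                throw new IllegalStateException("at least one message is required");
            }
            return new RequestSketch(this);
        }
    }
}
```

A caller that only supplies a message, as in RequestSketch.builder().message("hi").build(), still gets the default model, temperature and token limit, while a request with no message is rejected before it is ever sent.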
ChatCompletionResponse (chat completion response)
Declares the fields of the Response.
import lombok.Data;

import java.io.Serializable;
import java.util.List;

@Data
public class ChatCompletionResponse implements Serializable {

    private String id;
    private String object;
    private String model;
    private List<ChatChoice> choices;
    private long created;
    private Usage usage;
}
Message
The chat message object: it carries the message role, the message content, and the message name.
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.Data;

import java.io.Serializable;

@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Message implements Serializable {

    private String role;
    private String content;
    private String name;

    public Message() {
    }

    private Message(Builder builder) {
        this.role = builder.role;
        this.content = builder.content;
        this.name = builder.name;
    }

    public static Builder builder() {
        return new Builder();
    }

    public static final class Builder {

        private String role;
        private String content;
        private String name;

        public Builder() {
        }

        // Accepts the enum so callers cannot pass an arbitrary role string
        public Builder role(Constants.Role role) {
            this.role = role.getCode();
            return this;
        }

        public Builder content(String content) {
            this.content = content;
            return this;
        }

        public Builder name(String name) {
            this.name = name;
            return this;
        }

        public Message build() {
            return new Message(this);
        }
    }
}
other
Usage (token usage)
A field of the Response that records how many tokens were used.
import com.fasterxml.jackson.annotation.JsonProperty;

import java.io.Serializable;

public class Usage implements Serializable {

    @JsonProperty("prompt_tokens")
    private long promptTokens;

    @JsonProperty("completion_tokens")
    private long completionTokens;

    @JsonProperty("total_tokens")
    private long totalTokens;

    public long getPromptTokens() {
        return promptTokens;
    }

    public void setPromptTokens(long promptTokens) {
        this.promptTokens = promptTokens;
    }

    public long getCompletionTokens() {
        return completionTokens;
    }

    public void setCompletionTokens(long completionTokens) {
        this.completionTokens = completionTokens;
    }

    public long getTotalTokens() {
        return totalTokens;
    }

    public void setTotalTokens(long totalTokens) {
        this.totalTokens = totalTokens;
    }
}
OpenAiResponse
import lombok.Data;

import java.io.Serializable;
import java.util.List;

@Data
public class OpenAiResponse<T> implements Serializable {

    private String object;
    private List<T> data;
    private Error error;

    // Declared static so the nested type does not need an enclosing OpenAiResponse instance
    @Data
    public static class Error {
        private String message;
        private String type;
        private String param;
        private String code;
    }
}
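qa (question-answering model)
This model will stop being available soon and is much the same as the chat model, so the code is not listed here.
QAChoice
QACompletionRequest
The model and prompt parameters are required; the other parameters are optional.
QACompletionResponse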
interceptor
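OpenAiInterceptor (custom interceptor)
The auth method adds the token parameter to the URL and returns a new request with the authorization and content-type headers set. intercept applies that preprocessing and then hands the resulting request to the next interceptor in the chain (or to the actual call).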
// Assumption: Header and ContentType are Hutool's enums, and @NotNull is the JetBrains annotation
import cn.hutool.http.ContentType;
import cn.hutool.http.Header;
import okhttp3.HttpUrl;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import org.jetbrains.annotations.NotNull;

import java.io.IOException;

public class OpenAiInterceptor implements Interceptor {

    private String apiKey;
    private String authToken;

    public OpenAiInterceptor(String apiKey, String authToken) {
        this.apiKey = apiKey;
        this.authToken = authToken;
    }

    @NotNull
    @Override
    public Response intercept(Chain chain) throws IOException {
        return chain.proceed(this.auth(apiKey, chain.request()));
    }

    private Request auth(String apiKey, Request original) {
        // Append the token as a query parameter
        HttpUrl url = original.url().newBuilder()
                .addQueryParameter("token", authToken)
                .build();

        // Rebuild the request with the new URL and the auth / content-type headers
        return original.newBuilder()
                .url(url)
                .header(Header.AUTHORIZATION.getValue(), "Bearer " + apiKey)
                .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue())
                .method(original.method(), original.body())
                .build();
    }
}
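The hand-off between intercept and the rest of the chain is easier to see without the HTTP machinery. The following is a simplified model in plain Java, not OkHttp itself: SketchRequest, SketchInterceptor, SketchChain and AuthSketch are all hypothetical stand-ins, and the "network call" at the end of the chain just returns a string.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Simplified stand-ins for a request, an interceptor and a chain; all names are hypothetical.
final class SketchRequest {
    final String url;
    final Map<String, String> headers;

    SketchRequest(String url, Map<String, String> headers) {
        this.url = url;
        this.headers = headers;
    }
}

interface SketchInterceptor {
    String intercept(SketchChain chain);
}

final class SketchChain {
    private final List<SketchInterceptor> interceptors;
    private final int index;
    private final SketchRequest request;

    SketchChain(List<SketchInterceptor> interceptors, int index, SketchRequest request) {
        this.interceptors = interceptors;
        this.index = index;
        this.request = request;
    }

    SketchRequest request() {
        return request;
    }

    // Pass the (possibly rewritten) request to the next interceptor,
    // or pretend to send it once the end of the chain is reached.
    String proceed(SketchRequest next) {
        if (index == interceptors.size()) {
            return "SENT " + next.url + " " + next.headers;
        }
        return interceptors.get(index).intercept(new SketchChain(interceptors, index + 1, next));
    }
}

final class AuthSketch implements SketchInterceptor {
    private final String apiKey;
    private final String authToken;

    AuthSketch(String apiKey, String authToken) {
        this.apiKey = apiKey;
        this.authToken = authToken;
    }

    @Override
    public String intercept(SketchChain chain) {
        SketchRequest original = chain.request();
        // Same two steps as auth(): add the token to the URL, then set the headers.
        String separator = original.url.contains("?") ? "&" : "?";
        Map<String, String> headers = new LinkedHashMap<>(original.headers);
        headers.put("Authorization", "Bearer " + apiKey);
        headers.put("Content-Type", "application/json");
        return chain.proceed(new SketchRequest(original.url + separator + "token=" + authToken, headers));
    }
}
```

The original request is never mutated: each interceptor builds a new request and passes it forward, which is the same shape as original.newBuilder()...build() followed by chain.proceed(...).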
session
IOpenAiApi
Declares the API endpoints; each method takes the request object as its body.
import io.reactivex.Single;
import retrofit2.http.Body;
import retrofit2.http.POST;

public interface IOpenAiApi {

    // Question-answering (text completion) endpoint
    @POST("v1/completions")
    Single<QACompletionResponse> completions(@Body QACompletionRequest qaCompletionRequest);

    // Chat completion endpoint
    @POST("v1/chat/completions")
    Single<ChatCompletionResponse> completions(@Body ChatCompletionRequest chatCompletionRequest);
}
OpenAiSession
The session interface.
public interface OpenAiSession {

    QACompletionResponse completions(QACompletionRequest qaCompletionRequest);

    QACompletionResponse completions(String question);

    ChatCompletionResponse completions(ChatCompletionRequest chatCompletionRequest);
}
OpenAiSessionFactory
The session factory interface.
public interface OpenAiSessionFactory {

    OpenAiSession openSession();
}
Configuration
The configuration class.
@Slf4j
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Configuration {

    @Getter
    @NotNull
    private String apiKey;

    @Getter
    private String apiHost;

    @Getter
    private String authToken;
}
DefaultOpenAiSession
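Implements the OpenAiSession interface.
blockingGet() is a method on RxJava's Single that blocks the current thread. Here it turns the asynchronous result into a synchronous one: the calling thread waits until the asynchronous computation finishes and only then continues.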
import io.reactivex.Single;

public class DefaultOpenAiSession implements OpenAiSession {

    private IOpenAiApi openAiApi;

    public DefaultOpenAiSession(IOpenAiApi openAiApi) {
        this.openAiApi = openAiApi;
    }

    @Override
    public QACompletionResponse completions(QACompletionRequest qaCompletionRequest) {
        return this.openAiApi.completions(qaCompletionRequest).blockingGet();
    }

    @Override
    public QACompletionResponse completions(String question) {
        // Convenience overload: wrap the question in a request with only the prompt set
        QACompletionRequest request = QACompletionRequest
                .builder()
                .prompt(question)
                .build();
        Single<QACompletionResponse> completions = this.openAiApi.completions(request);
        return completions.blockingGet();
    }

    @Override
    public ChatCompletionResponse completions(ChatCompletionRequest chatCompletionRequest) {
        return this.openAiApi.completions(chatCompletionRequest).blockingGet();
    }
}
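The same "wrap an async call in a synchronous method" shape can be shown with the JDK alone. This is an analogy rather than RxJava: CompletableFuture.join() stands in for Single.blockingGet(), and BlockingSketch, askAsync and ask are hypothetical names.

```java
import java.util.concurrent.CompletableFuture;

// Analogy only: join() plays the role that blockingGet() plays for an RxJava Single.
final class BlockingSketch {

    // Hypothetical async call whose result is produced on another thread.
    static CompletableFuture<String> askAsync(String question) {
        return CompletableFuture.supplyAsync(() -> "echo: " + question);
    }

    // Synchronous facade: the caller blocks here until the async result is ready.
    static String ask(String question) {
        return askAsync(question).join();
    }
}
```

The trade-off is the usual one: the synchronous facade is simple to call, but the calling thread does nothing else while it waits for the response.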
DefaultOpenAiSessionFactory
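Implements the OpenAiSessionFactory interface. What it really does is have Retrofit create an implementation of the IOpenAiApi interface, then return a DefaultOpenAiSession(openAiApi) wrapping it.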
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

import java.util.concurrent.TimeUnit;

public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {

    private final Configuration configuration;

    public DefaultOpenAiSessionFactory(Configuration configuration) {
        this.configuration = configuration;
    }

    @Override
    public OpenAiSession openSession() {
        // 1. Logging: print request and response headers
        HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor();
        httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);

        // 2. HTTP client with the logging and auth interceptors and generous timeouts
        OkHttpClient okHttpClient = new OkHttpClient
                .Builder()
                .addInterceptor(httpLoggingInterceptor)
                .addInterceptor(new OpenAiInterceptor(configuration.getApiKey(), configuration.getAuthToken()))
                .connectTimeout(450, TimeUnit.SECONDS)
                .writeTimeout(450, TimeUnit.SECONDS)
                .readTimeout(450, TimeUnit.SECONDS)
                .build();

        // 3. Let Retrofit create the IOpenAiApi implementation
        IOpenAiApi openAiApi = new Retrofit.Builder()
                .baseUrl(configuration.getApiHost())
                .client(okHttpClient)
                .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
                .addConverterFactory(JacksonConverterFactory.create())
                .build().create(IOpenAiApi.class);

        return new DefaultOpenAiSession(openAiApi);
    }
}
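Nobody writes a class that implements IOpenAiApi; Retrofit's create() supplies the implementation at runtime through a JDK dynamic proxy. The sketch below shows only that mechanism, under simplifying assumptions: EchoApi and ProxySketch are hypothetical, and the handler returns a string where a real client would build and execute an HTTP call from the method's annotations.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

// Hypothetical interface standing in for IOpenAiApi.
interface EchoApi {
    String completions(String prompt);
}

final class ProxySketch {

    // Creates an implementation of the given interface at runtime.
    @SuppressWarnings("unchecked")
    static <T> T create(Class<T> api) {
        InvocationHandler handler = (proxy, method, args) -> {
            // A real client would turn the method and its arguments into an HTTP request here.
            String arg = (args == null || args.length == 0) ? "" : String.valueOf(args[0]);
            return method.getName() + "(" + arg + ")";
        };
        return (T) Proxy.newProxyInstance(api.getClassLoader(), new Class<?>[]{api}, handler);
    }
}
```

Calling ProxySketch.create(EchoApi.class).completions("hi") routes through the handler, which is why the session only ever depends on the interface.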
Unit test
Passing in the URL, the key, the token, and a request is all it takes.
import lombok.extern.slf4j.Slf4j;
import org.junit.Before;
import org.junit.Test;

import java.util.Collections;

@Slf4j
public class ApiTest {

    private OpenAiSession openAiSession;

    @Before
    public void test_OpenAiSessionFactory() {
        // 1. Configuration: host, API key, and auth token
        Configuration configuration = new Configuration();
        configuration.setApiHost("https://api.openai-proxy.com/");
        configuration.setApiKey("xxx");
        configuration.setAuthToken("xxx");
        // 2. Session factory
        OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
        // 3. Open a session
        this.openAiSession = factory.openSession();
    }

    @Test
    public void test_chat_completions() {
        // Build a request with a single user message
        ChatCompletionRequest chatCompletion = ChatCompletionRequest
                .builder()
                .messages(Collections.singletonList(
                        Message.builder().role(Constants.Role.USER).content("Write a bubble sort in Java").build()))
                .model(ChatCompletionRequest.Model.GPT_3_5_TURBO.getCode())
                .build();
        // Send it and log each returned choice
        ChatCompletionResponse chatCompletionResponse = openAiSession.completions(chatCompletion);
        chatCompletionResponse.getChoices().forEach(e -> {
            log.info("Test result: {}", e.getMessage());
        });
    }
}
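The test hard-codes "xxx" placeholders for the key and token. One possible way to keep real values out of the source is to resolve them from environment variables. This is a sketch under assumptions: ConfigSketch, the variable names OPENAI_API_KEY and OPENAI_API_HOST, and the fallback host are all hypothetical choices, not part of the project.

```java
import java.util.Map;

// Hypothetical helper; pass System.getenv() as the map in real use.
final class ConfigSketch {
    final String apiHost;
    final String apiKey;

    ConfigSketch(String apiHost, String apiKey) {
        this.apiHost = apiHost;
        this.apiKey = apiKey;
    }

    static ConfigSketch fromEnv(Map<String, String> env) {
        String key = env.get("OPENAI_API_KEY");
        if (key == null || key.isEmpty()) {
            throw new IllegalStateException("OPENAI_API_KEY is not set");
        }
        // Assumed fallback host; the test above points at a proxy host instead.
        String host = env.getOrDefault("OPENAI_API_HOST", "https://api.openai.com/");
        // Retrofit rejects a base URL that does not end with '/'.
        return new ConfigSketch(host.endsWith("/") ? host : host + "/", key);
    }
}
```

Taking the map as a parameter instead of calling System.getenv() inside the method keeps the helper easy to exercise from a unit test.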